Commit 19173ab7 authored by Chao Liu's avatar Chao Liu
Browse files

add G

parent c3379310
...@@ -25,7 +25,7 @@ void print_helper_msg() ...@@ -25,7 +25,7 @@ void print_helper_msg()
<< "arg3: time kernel (0=no, 1=yes)\n" << "arg3: time kernel (0=no, 1=yes)\n"
<< "arg4: N spatial dimensions (default 2)\n" << "arg4: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n" << "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n" << " G, N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n" << " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n" << " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n" << " <strides>, (ie Sy, Sx for 2D)\n"
...@@ -37,6 +37,7 @@ void print_helper_msg() ...@@ -37,6 +37,7 @@ void print_helper_msg()
ck::utils::conv::ConvParam parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[]) ck::utils::conv::ConvParam parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[])
{ {
const ck::index_t G = std::stoi(argv[arg_idx++]);
const ck::index_t N = std::stoi(argv[arg_idx++]); const ck::index_t N = std::stoi(argv[arg_idx++]);
const ck::index_t K = std::stoi(argv[arg_idx++]); const ck::index_t K = std::stoi(argv[arg_idx++]);
const ck::index_t C = std::stoi(argv[arg_idx++]); const ck::index_t C = std::stoi(argv[arg_idx++]);
...@@ -79,6 +80,7 @@ ck::utils::conv::ConvParam parse_conv_params(int num_dim_spatial, int arg_idx, c ...@@ -79,6 +80,7 @@ ck::utils::conv::ConvParam parse_conv_params(int num_dim_spatial, int arg_idx, c
} }
return ck::utils::conv::ConvParam{num_dim_spatial, return ck::utils::conv::ConvParam{num_dim_spatial,
G,
N, N,
K, K,
C, C,
...@@ -110,23 +112,56 @@ int run_conv_fwd(bool do_verification, ...@@ -110,23 +112,56 @@ int run_conv_fwd(bool do_verification,
const WeiElementOp& wei_element_op, const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op) const OutElementOp& out_element_op)
{ {
const auto in_desc = ck::utils::conv::get_input_host_tensor_descriptor<InLayout>(conv_param); #if 0
const auto wei_desc = ck::utils::conv::get_weight_host_tensor_descriptor<WeiLayout>(conv_param); const auto in_g_n_c_wis_desc = ck::utils::conv::get_input_host_tensor_descriptor<InLayout>(conv_param);
const auto out_desc = ck::utils::conv::get_output_host_tensor_descriptor<OutLayout>(conv_param); const auto wei_g_k_c_xs_desc = ck::utils::conv::get_weight_host_tensor_descriptor<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc = ck::utils::conv::get_output_host_tensor_descriptor<OutLayout>(conv_param);
// hacky, hardcoded for 2d NHWK #else
const auto bias_desc = HostTensorDescriptor( const auto in_g_n_wis_c_desc = HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(conv_param.N_), std::vector<std::size_t>{static_cast<std::size_t>(conv_param.G_),
static_cast<std::size_t>(conv_param.N_),
static_cast<std::size_t>(conv_param.input_spatial_lengths_[0]),
static_cast<std::size_t>(conv_param.input_spatial_lengths_[1]),
static_cast<std::size_t>(conv_param.C_)});
const auto wei_g_k_xs_c_desc = HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(conv_param.G_),
static_cast<std::size_t>(conv_param.K_),
static_cast<std::size_t>(conv_param.filter_spatial_lengths_[0]),
static_cast<std::size_t>(conv_param.filter_spatial_lengths_[1]),
static_cast<std::size_t>(conv_param.C_)});
const auto bias_g_n_wos_k_desc = HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(conv_param.G_),
static_cast<std::size_t>(conv_param.N_),
static_cast<std::size_t>(conv_param.output_spatial_lengths_[0]), static_cast<std::size_t>(conv_param.output_spatial_lengths_[0]),
static_cast<std::size_t>(conv_param.output_spatial_lengths_[1]), static_cast<std::size_t>(conv_param.output_spatial_lengths_[1]),
static_cast<std::size_t>(conv_param.K_)}, static_cast<std::size_t>(conv_param.K_)},
std::vector<std::size_t>{0, 0, 0, 1}); std::vector<std::size_t>{0, 0, 0, 0, 1});
const auto out_g_n_wos_k_desc = HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(conv_param.G_),
static_cast<std::size_t>(conv_param.N_),
static_cast<std::size_t>(conv_param.output_spatial_lengths_[0]),
static_cast<std::size_t>(conv_param.output_spatial_lengths_[1]),
static_cast<std::size_t>(conv_param.K_)});
Tensor<InDataType> in(in_desc); // tensor descriptor in NCHW/KXYC/NKHW dimensional order
Tensor<WeiDataType> wei(wei_desc); const auto in_g_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
Tensor<OutDataType> bias(bias_desc); in_g_n_wis_c_desc, std::vector<ck::index_t>{0, 1, 4, 2, 3});
Tensor<OutDataType> out_host(out_desc); const auto wei_g_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
Tensor<OutDataType> out_device(out_desc); wei_g_k_xs_c_desc, std::vector<ck::index_t>{0, 1, 4, 2, 3});
const auto bias_g_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
bias_g_n_wos_k_desc, std::vector<ck::index_t>{0, 1, 4, 2, 3});
const auto out_g_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_g_n_wos_k_desc, std::vector<ck::index_t>{0, 1, 4, 2, 3});
#endif
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<OutDataType> bias(bias_g_n_k_wos_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl; std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl; std::cout << "wei: " << wei.mDesc << std::endl;
...@@ -156,80 +191,14 @@ int run_conv_fwd(bool do_verification, ...@@ -156,80 +191,14 @@ int run_conv_fwd(bool do_verification,
wei_device_buf.ToDevice(wei.mData.data()); wei_device_buf.ToDevice(wei.mData.data());
bias_device_buf.ToDevice(bias.mData.data()); bias_device_buf.ToDevice(bias.mData.data());
// tensor descriptor in NCHW/KXYC/NKHW dimensional order std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
HostTensorDescriptor in_n_c_wis_desc = in_desc; std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
HostTensorDescriptor wei_k_c_xs_desc = wei_desc; std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
HostTensorDescriptor bias_n_k_wos_desc = bias_desc; std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
HostTensorDescriptor out_n_k_wos_desc = out_desc; std::array<ck::index_t, NDimSpatial + 3> d_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> d_g_n_k_wos_strides{};
// input std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC>) std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
{
in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC>)
{
in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// weight
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
{
wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC>)
{
wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// output
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK>)
{
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 2, 1});
bias_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
bias_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>)
{
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 3, 1, 2});
bias_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
bias_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
bias_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
bias_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 2> d_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 2> d_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{}; std::array<ck::index_t, NDimSpatial> input_left_pads{};
...@@ -237,14 +206,14 @@ int run_conv_fwd(bool do_verification, ...@@ -237,14 +206,14 @@ int run_conv_fwd(bool do_verification,
auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
copy(in_n_c_wis_desc.GetLengths(), a_n_c_wis_lengths); copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_n_c_wis_desc.GetStrides(), a_n_c_wis_strides); copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_k_c_xs_desc.GetLengths(), b_k_c_xs_lengths); copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_k_c_xs_desc.GetStrides(), b_k_c_xs_strides); copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(bias_n_k_wos_desc.GetLengths(), d_n_k_wos_lengths); copy(bias_g_n_k_wos_desc.GetLengths(), d_g_n_k_wos_lengths);
copy(bias_n_k_wos_desc.GetStrides(), d_n_k_wos_strides); copy(bias_g_n_k_wos_desc.GetStrides(), d_g_n_k_wos_strides);
copy(out_n_k_wos_desc.GetLengths(), e_n_k_wos_lengths); copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_n_k_wos_desc.GetStrides(), e_n_k_wos_strides); copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides); copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations); copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_left_pads_, input_left_pads);
...@@ -258,14 +227,14 @@ int run_conv_fwd(bool do_verification, ...@@ -258,14 +227,14 @@ int run_conv_fwd(bool do_verification,
wei_device_buf.GetDeviceBuffer(), wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()}, std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
out_device_buf.GetDeviceBuffer(), out_device_buf.GetDeviceBuffer(),
a_n_c_wis_lengths, a_g_n_c_wis_lengths,
a_n_c_wis_strides, a_g_n_c_wis_strides,
b_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_k_c_xs_strides, b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 2>, 1>{{d_n_k_wos_lengths}}, std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d_g_n_k_wos_lengths}},
std::array<std::array<ck::index_t, NDimSpatial + 2>, 1>{{d_n_k_wos_strides}}, std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d_g_n_k_wos_strides}},
e_n_k_wos_lengths, e_g_n_k_wos_lengths,
e_n_k_wos_strides, e_g_n_k_wos_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -295,7 +264,7 @@ int run_conv_fwd(bool do_verification, ...@@ -295,7 +264,7 @@ int run_conv_fwd(bool do_verification,
{ {
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
Tensor<OutDataType> c_host(out_desc); Tensor<OutDataType> c_host(out_g_n_k_wos_desc);
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial, auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InLayout, InLayout,
...@@ -322,16 +291,20 @@ int run_conv_fwd(bool do_verification, ...@@ -322,16 +291,20 @@ int run_conv_fwd(bool do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(int n = 0; n < out_host.mDesc.GetLengths()[0]; n++) for(int g = 0; g < out_host.mDesc.GetLengths()[0]; g++)
{ {
for(int ho = 0; ho < out_host.mDesc.GetLengths()[1]; ho++) for(int n = 0; n < out_host.mDesc.GetLengths()[1]; n++)
{ {
for(int wo = 0; wo < out_host.mDesc.GetLengths()[2]; wo++) for(int k = 0; k < out_host.mDesc.GetLengths()[2]; k++)
{ {
for(int k = 0; k < out_host.mDesc.GetLengths()[3]; k++) for(int ho = 0; ho < out_host.mDesc.GetLengths()[3]; ho++)
{ {
out_element_op( for(int wo = 0; wo < out_host.mDesc.GetLengths()[4]; wo++)
out_host(n, ho, wo, k), c_host(n, ho, wo, k), bias(n, ho, wo, k)); {
out_element_op(out_host(g, n, k, ho, wo),
c_host(g, n, k, ho, wo),
bias(g, n, k, ho, wo));
}
} }
} }
} }
......
...@@ -138,7 +138,7 @@ int main(int argc, char* argv[]) ...@@ -138,7 +138,7 @@ int main(int argc, char* argv[])
int num_dim_spatial = 2; int num_dim_spatial = 2;
ck::utils::conv::ConvParam params{ ck::utils::conv::ConvParam params{
2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
if(argc == 1) if(argc == 1)
{ {
......
...@@ -39,14 +39,14 @@ struct DeviceConvFwdMultipleD : public BaseOperator ...@@ -39,14 +39,14 @@ struct DeviceConvFwdMultipleD : public BaseOperator
const void* p_b, const void* p_b,
const std::array<const void*, NumDTensor>& p_ds, const std::array<const void*, NumDTensor>& p_ds,
void* p_e, void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths, const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides, const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
......
...@@ -28,39 +28,152 @@ namespace device { ...@@ -28,39 +28,152 @@ namespace device {
namespace { namespace {
// Translates a batch/group index into per-tensor pointer offsets for the A, B,
// Ds and E tensors. Every offset is the index multiplied by that tensor's batch
// stride, with the product evaluated in long_index_t.
template <index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch
{
    // The defaulted constructor does not assign any of the stride members.
    ComputePtrOffsetOfStridedBatch() = default;

    // Records one batch stride per tensor; BatchStrideDs carries NumDTensor
    // entries, one for each D tensor.
    ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                                   index_t BatchStrideB,
                                   Array<ck::index_t, NumDTensor> BatchStrideDs,
                                   index_t BatchStrideE)
        : BatchStrideA_(BatchStrideA),
          BatchStrideB_(BatchStrideB),
          BatchStrideDs_(BatchStrideDs),
          BatchStrideE_(BatchStrideE)
    {
    }

    // Offset of batch g_idx within the A tensor.
    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
    {
        return static_cast<long_index_t>(BatchStrideA_) * g_idx;
    }

    // Offset of batch g_idx within the B tensor.
    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
    {
        return static_cast<long_index_t>(BatchStrideB_) * g_idx;
    }

    // Offsets of batch g_idx within each D tensor, returned as an Array with
    // one long_index_t entry per D tensor.
    __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
    {
        Array<long_index_t, NumDTensor> offsets;

        // Compile-time loop over the D tensors; each entry uses its own stride.
        static_for<0, NumDTensor, 1>{}([&](auto d) {
            offsets(d) = static_cast<long_index_t>(BatchStrideDs_[d]) * g_idx;
        });

        return offsets;
    }

    // Offset of batch g_idx within the E tensor.
    __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
    {
        return static_cast<long_index_t>(BatchStrideE_) * g_idx;
    }

    // Batch strides, one per tensor (one per D tensor for BatchStrideDs_).
    index_t BatchStrideA_;
    index_t BatchStrideB_;
    Array<ck::index_t, NumDTensor> BatchStrideDs_;
    index_t BatchStrideE_;
};
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B, Ds and E tensors
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the E-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename ABDataType,
typename FloatDsPointer, typename DsPointer,
typename FloatE, typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap, typename Block2ETileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_multiple_d_xdl_cshuffle(const FloatAB* __restrict__ p_a_grid, kernel_batch_gemm_multiple_d_xdl_cshuffle(
const FloatAB* __restrict__ p_b_grid, const ABDataType* __restrict__ p_a_grid,
FloatDsPointer p_ds_grid, const ABDataType* __restrict__ p_b_grid,
FloatE* __restrict__ p_e_grid, DsPointer p_ds_grid,
const AElementwiseOperation a_element_op, EDataType* __restrict__ p_e_grid,
const BElementwiseOperation b_element_op, const AElementwiseOperation a_element_op,
const CDEElementwiseOperation cde_element_op, const BElementwiseOperation b_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const CDEElementwiseOperation cde_element_op,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const index_t batch_count,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
ds_grid_desc_mblock_mperblock_nblock_nperblock, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map) const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if 1
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
DsPointer p_ds_grid_grp;
static constexpr index_t NumDTensor =
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map);
#else
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...@@ -71,26 +184,31 @@ __global__ void ...@@ -71,26 +184,31 @@ __global__ void
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op, cde_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_k0_m_k1,
b_grid_desc_bk0_n_bk1, b_grid_desc_k0_n_k1,
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_etile_map); block_2_ctile_map);
#endif
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_ds_grid; ignore = p_ds_grid;
ignore = p_e_grid; ignore = p_e_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = cde_element_op; ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = compute_ptr_offset_of_batch;
ignore = b_grid_desc_bk0_n_bk1; ignore = block_2_ctile_map;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif #endif
} }
} // namespace } // namespace
// //
...@@ -187,33 +305,33 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -187,33 +305,33 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v<ALay, tensor_layout::convolution::NWC>, is_same_v<ALay, tensor_layout::convolution::NWC>,
bool>::type = false> bool>::type = false>
static auto static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) const std::array<index_t, NDimSpatial>& input_right_pads)
{ {
const index_t N = a_n_c_wis_lengths[0]; const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_n_c_wis_lengths[1]; const index_t C = a_g_n_c_wis_lengths[2];
const index_t GemmMRaw = N * std::accumulate(e_n_k_wos_lengths.begin() + 2, const index_t GemmMRaw = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_n_k_wos_lengths.begin() + 3, e_g_n_k_wos_lengths.begin() + 4,
index_t{1}, index_t{1},
std::multiplies<index_t>()); std::multiplies<index_t>());
const index_t GemmKRaw = C * std::accumulate(b_k_c_xs_lengths.begin() + 2, const index_t GemmKRaw = C * std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_k_c_xs_lengths.begin() + 3, b_g_k_c_xs_lengths.begin() + 4,
index_t{1}, index_t{1},
std::multiplies<index_t>()); std::multiplies<index_t>());
const index_t Wi = a_n_c_wis_lengths[2]; const index_t Wi = a_g_n_c_wis_lengths[3];
const index_t Wo = e_n_k_wos_lengths[2]; const index_t Wo = e_g_n_k_wos_lengths[3];
const index_t ConvStrideW = conv_filter_strides[0]; const index_t ConvStrideW = conv_filter_strides[0];
...@@ -255,7 +373,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -255,7 +373,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
else else
{ {
const index_t X = b_k_c_xs_lengths[2]; const index_t X = b_g_k_c_xs_lengths[3];
const index_t ConvDilationW = conv_filter_dilations[0]; const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0]; const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0]; const index_t InRightPadW = input_right_pads[0];
...@@ -299,35 +417,35 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -299,35 +417,35 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v<ALay, tensor_layout::convolution::NHWC>, is_same_v<ALay, tensor_layout::convolution::NHWC>,
bool>::type = false> bool>::type = false>
static auto static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) const std::array<index_t, NDimSpatial>& input_right_pads)
{ {
const index_t N = a_n_c_wis_lengths[0]; const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_n_c_wis_lengths[1]; const index_t C = a_g_n_c_wis_lengths[2];
const index_t GemmMRaw = N * std::accumulate(e_n_k_wos_lengths.begin() + 2, const index_t GemmMRaw = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_n_k_wos_lengths.begin() + 4, e_g_n_k_wos_lengths.begin() + 5,
index_t{1}, index_t{1},
std::multiplies<index_t>()); std::multiplies<index_t>());
const index_t GemmKRaw = C * std::accumulate(b_k_c_xs_lengths.begin() + 2, const index_t GemmKRaw = C * std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_k_c_xs_lengths.begin() + 4, b_g_k_c_xs_lengths.begin() + 5,
index_t{1}, index_t{1},
std::multiplies<index_t>()); std::multiplies<index_t>());
const index_t Hi = a_n_c_wis_lengths[2]; const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_n_c_wis_lengths[3]; const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = e_n_k_wos_lengths[2]; const index_t Ho = e_g_n_k_wos_lengths[3];
const index_t Wo = e_n_k_wos_lengths[3]; const index_t Wo = e_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0]; const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1]; const index_t ConvStrideW = conv_filter_strides[1];
...@@ -372,8 +490,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -372,8 +490,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
else else
{ {
const index_t Y = b_k_c_xs_lengths[2]; const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_k_c_xs_lengths[3]; const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0]; const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1]; const index_t ConvDilationW = conv_filter_dilations[1];
...@@ -425,37 +543,37 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -425,37 +543,37 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v<ALay, tensor_layout::convolution::NDHWC>, is_same_v<ALay, tensor_layout::convolution::NDHWC>,
bool>::type = false> bool>::type = false>
static auto static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) const std::array<index_t, NDimSpatial>& input_right_pads)
{ {
const index_t N = a_n_c_wis_lengths[0]; const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_n_c_wis_lengths[1]; const index_t C = a_g_n_c_wis_lengths[2];
const index_t GemmMRaw = N * std::accumulate(e_n_k_wos_lengths.begin() + 2, const index_t GemmMRaw = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_n_k_wos_lengths.begin() + 5, e_g_n_k_wos_lengths.begin() + 6,
index_t{1}, index_t{1},
std::multiplies<index_t>()); std::multiplies<index_t>());
const index_t GemmKRaw = C * std::accumulate(b_k_c_xs_lengths.begin() + 2, const index_t GemmKRaw = C * std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_k_c_xs_lengths.begin() + 5, b_g_k_c_xs_lengths.begin() + 6,
index_t{1}, index_t{1},
std::multiplies<index_t>()); std::multiplies<index_t>());
const index_t Di = a_n_c_wis_lengths[2]; const index_t Di = a_g_n_c_wis_lengths[3];
const index_t Hi = a_n_c_wis_lengths[3]; const index_t Hi = a_g_n_c_wis_lengths[4];
const index_t Wi = a_n_c_wis_lengths[4]; const index_t Wi = a_g_n_c_wis_lengths[5];
const index_t Do = e_n_k_wos_lengths[2]; const index_t Do = e_g_n_k_wos_lengths[3];
const index_t Ho = e_n_k_wos_lengths[3]; const index_t Ho = e_g_n_k_wos_lengths[4];
const index_t Wo = e_n_k_wos_lengths[4]; const index_t Wo = e_g_n_k_wos_lengths[5];
const index_t ConvStrideD = conv_filter_strides[0]; const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1]; const index_t ConvStrideH = conv_filter_strides[1];
...@@ -504,9 +622,9 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -504,9 +622,9 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
else else
{ {
const index_t Z = b_k_c_xs_lengths[2]; const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_k_c_xs_lengths[3]; const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_k_c_xs_lengths[4]; const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0]; const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1]; const index_t ConvDilationH = conv_filter_dilations[1];
...@@ -571,16 +689,16 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -571,16 +689,16 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v<BLay, tensor_layout::convolution::KZYXC>, is_same_v<BLay, tensor_layout::convolution::KZYXC>,
bool>::type = false> bool>::type = false>
static auto static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths, MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides) const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const index_t K = b_k_c_xs_lengths[0]; const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_k_c_xs_lengths[1]; const index_t C = b_g_k_c_xs_lengths[2];
const index_t GemmNRaw = K; const index_t GemmNRaw = K;
const index_t GemmKRaw = C * std::accumulate(b_k_c_xs_lengths.begin() + 2, const index_t GemmKRaw = C * std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_k_c_xs_lengths.begin() + 2 + NDimSpatial, b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
index_t{1}, index_t{1},
std::multiplies<index_t>()); std::multiplies<index_t>());
...@@ -599,14 +717,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -599,14 +717,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v<ELay, tensor_layout::convolution::NDHWK>, is_same_v<ELay, tensor_layout::convolution::NDHWK>,
bool>::type = false> bool>::type = false>
static auto static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths, MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const index_t N = e_n_k_wos_lengths[0]; const index_t N = e_g_n_k_wos_lengths[1];
const index_t K = e_n_k_wos_lengths[1]; const index_t K = e_g_n_k_wos_lengths[2];
const index_t GemmMRaw = N * std::accumulate(e_n_k_wos_lengths.begin() + 2, const index_t GemmMRaw = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_n_k_wos_lengths.begin() + 2 + NDimSpatial, e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1}, index_t{1},
std::multiplies<index_t>()); std::multiplies<index_t>());
...@@ -627,18 +745,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -627,18 +745,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v<ELay, tensor_layout::convolution::NDHW_K>, is_same_v<ELay, tensor_layout::convolution::NDHW_K>,
bool>::type = false> bool>::type = false>
static auto static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths, MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
namespace ctc = ck::tensor_layout::convolution; namespace ctc = ck::tensor_layout::convolution;
const index_t N = e_n_k_wos_lengths[0]; const index_t N = e_g_n_k_wos_lengths[1];
const index_t K = e_n_k_wos_lengths[1]; const index_t K = e_g_n_k_wos_lengths[2];
const index_t WoStride = e_n_k_wos_strides[NDimSpatial + 1]; const index_t WoStride = e_g_n_k_wos_strides[NDimSpatial + 2];
const index_t GemmMRaw = N * std::accumulate(e_n_k_wos_lengths.begin() + 2, const index_t GemmMRaw = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_n_k_wos_lengths.begin() + 2 + NDimSpatial, e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1}, index_t{1},
std::multiplies<index_t>()); std::multiplies<index_t>());
...@@ -654,15 +772,15 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -654,15 +772,15 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
static auto MakeDsGridDescriptor_M_N( static auto MakeDsGridDescriptor_M_N(
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths, const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides) const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_n_k_wos_lengths[i], return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i],
ds_n_k_wos_strides[i]); ds_g_n_k_wos_strides[i]);
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
...@@ -731,26 +849,27 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -731,26 +849,27 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(const void* p_a,
const void* p_a, const void* p_b,
const void* p_b, const std::array<const void*, NumDTensor>& p_ds,
const std::array<const void*, NumDTensor>& p_ds, void* p_e,
void* p_e, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides, const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths, ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides, const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths, ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_right_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const AElementwiseOperation& a_element_op, const std::array<index_t, NDimSpatial>& input_right_pads,
const BElementwiseOperation& b_element_op, const AElementwiseOperation& a_element_op,
const CDEElementwiseOperation& cde_element_op) const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a)}, : p_a_grid_{static_cast<const ADataType*>(p_a)},
p_b_grid_{static_cast<const BDataType*>(p_b)}, p_b_grid_{static_cast<const BDataType*>(p_b)},
p_ds_grid_{}, // FIXME p_ds_grid_{}, // FIXME
...@@ -764,56 +883,72 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -764,56 +883,72 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{}, block_2_etile_map_{},
compute_ptr_offset_of_batch_{},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
a_n_c_wis_lengths_{a_n_c_wis_lengths}, a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_n_c_wis_strides_{a_n_c_wis_strides}, a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_k_c_xs_lengths_{b_k_c_xs_lengths}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_k_c_xs_strides_{b_k_c_xs_strides}, b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_n_k_wos_lengths_{ds_n_k_wos_lengths}, ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_n_k_wos_strides_{ds_n_k_wos_strides}, ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_n_k_wos_lengths_{e_n_k_wos_lengths}, e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_n_k_wos_strides_{e_n_k_wos_strides}, e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations}, conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads} input_right_pads_{input_right_pads}
{ {
a_grid_desc_m_k_ = DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_n_c_wis_lengths, // A desc
a_n_c_wis_strides, a_grid_desc_m_k_ = DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
b_k_c_xs_lengths, a_g_n_c_wis_strides,
b_k_c_xs_strides, b_g_k_c_xs_lengths,
e_n_k_wos_lengths, b_g_k_c_xs_strides,
e_n_k_wos_strides, e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads);
// B Desc
b_grid_desc_n_k_ = b_grid_desc_n_k_ =
DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_k_c_xs_lengths, b_k_c_xs_strides); DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
e_grid_desc_m_n_ = // E Desc
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_n_k_wos_lengths, e_n_k_wos_strides); e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
// A Des
a_grid_desc_ak0_m_ak1_ = a_grid_desc_ak0_m_ak1_ =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_); GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_);
// B Desc
b_grid_desc_bk0_n_bk1_ = b_grid_desc_bk0_n_bk1_ =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_); GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_);
// Block-to-e-tile
block_2_etile_map_ = Block2ETileMap{e_grid_desc_m_n_}; block_2_etile_map_ = Block2ETileMap{e_grid_desc_m_n_};
// populate pointer and desc for Ds // A/B/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]); p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
// D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>( ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
ds_n_k_wos_lengths[i], ds_n_k_wos_strides[i]); ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
}); });
// populate desc for Ds/E // populate desc for Ds/E
...@@ -865,20 +1000,22 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -865,20 +1000,22 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// block-to-e-tile map // block-to-e-tile map
Block2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
// element-wise op // element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 2> a_n_c_wis_lengths_; std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 2> a_n_c_wis_strides_; std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 2> b_k_c_xs_lengths_; std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 2> b_k_c_xs_strides_; std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor> ds_n_k_wos_lengths_; std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor> ds_n_k_wos_strides_; std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 2> e_n_k_wos_lengths_; std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 2> e_n_k_wos_strides_; std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_; std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_; std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_; std::array<index_t, NDimSpatial> input_left_pads_;
...@@ -906,7 +1043,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -906,7 +1043,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
const index_t grid_size = const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) *
arg.a_g_n_c_wis_lengths_[0];
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
...@@ -914,7 +1052,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -914,7 +1052,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle< const auto kernel = kernel_batch_gemm_multiple_d_xdl_cshuffle<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer, typename GridwiseGemm::DsGridPointer,
...@@ -927,6 +1065,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -927,6 +1065,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap, Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumDTensor>,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
...@@ -941,11 +1080,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -941,11 +1080,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_); arg.block_2_etile_map_,
arg.compute_ptr_offset_of_batch_);
}; };
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
...@@ -991,6 +1132,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -991,6 +1132,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false; return false;
} }
int itmp = 0;
printf("%d\n", itmp++);
// check ConvolutionForwardSpecialization // check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
...@@ -998,7 +1143,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -998,7 +1143,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i) for(index_t i = 0; i < NDimSpatial; ++i)
{ {
const index_t X = arg.b_k_c_xs_lengths_[i + 2]; const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i]; const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i]; const index_t RightPad = arg.input_right_pads_[i];
...@@ -1015,7 +1160,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1015,7 +1160,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// check if it's 1x1 conv // check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i) for(index_t i = 0; i < NDimSpatial; ++i)
{ {
const index_t X = arg.b_k_c_xs_lengths_[i + 2]; const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
const index_t LeftPad = arg.input_left_pads_[i]; const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i]; const index_t RightPad = arg.input_right_pads_[i];
...@@ -1026,11 +1171,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1026,11 +1171,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
} }
printf("%d\n", itmp++);
// check vector access of A // check vector access of A
if constexpr(is_same_v<ALayout, ctc::NWC> || is_same_v<ALayout, ctc::NHWC> || if constexpr(is_same_v<ALayout, ctc::NWC> || is_same_v<ALayout, ctc::NHWC> ||
is_same_v<ALayout, ctc::NDHWC>) is_same_v<ALayout, ctc::NDHWC>)
{ {
const index_t C = arg.a_n_c_wis_lengths_[1]; const index_t C = arg.a_g_n_c_wis_lengths_[2];
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{ {
...@@ -1042,11 +1189,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1042,11 +1189,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false; return false;
} }
printf("%d\n", itmp++);
// check vector access of B // check vector access of B
if constexpr(is_same_v<BLayout, ctc::KXC> || is_same_v<BLayout, ctc::KYXC> || if constexpr(is_same_v<BLayout, ctc::KXC> || is_same_v<BLayout, ctc::KYXC> ||
is_same_v<BLayout, ctc::KZYXC>) is_same_v<BLayout, ctc::KZYXC>)
{ {
const index_t C = arg.b_k_c_xs_lengths_[1]; const index_t C = arg.b_g_k_c_xs_lengths_[2];
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{ {
...@@ -1058,6 +1207,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1058,6 +1207,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false; return false;
} }
printf("%d\n", itmp++);
// check vector access of Ds // check vector access of Ds
bool valid = true; bool valid = true;
...@@ -1068,7 +1219,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1068,7 +1219,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v<DLayout, ctc::NDHWK> || is_same_v<DLayout, ctc::NW_K> || is_same_v<DLayout, ctc::NDHWK> || is_same_v<DLayout, ctc::NW_K> ||
is_same_v<DLayout, ctc::NHW_K> || is_same_v<DLayout, ctc::NDHW_K>) is_same_v<DLayout, ctc::NHW_K> || is_same_v<DLayout, ctc::NDHW_K>)
{ {
const index_t K = arg.ds_n_k_wos_lengths_[i][1]; const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{ {
...@@ -1086,11 +1237,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1086,11 +1237,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false; return false;
} }
printf("%d\n", itmp++);
// check vector access of E // check vector access of E
if constexpr(is_same_v<ELayout, ctc::NWK> || is_same_v<ELayout, ctc::NHWK> || if constexpr(is_same_v<ELayout, ctc::NWK> || is_same_v<ELayout, ctc::NHWK> ||
is_same_v<ELayout, ctc::NDHWK>) is_same_v<ELayout, ctc::NDHWK>)
{ {
const index_t K = arg.e_n_k_wos_lengths_[1]; const index_t K = arg.e_g_n_k_wos_lengths_[2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{ {
...@@ -1102,6 +1255,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1102,6 +1255,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false; return false;
} }
printf("%d\n", itmp++);
// check Gridwise GEMM // check Gridwise GEMM
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
...@@ -1120,14 +1275,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1120,14 +1275,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
const void* p_b, const void* p_b,
const std::array<const void*, NumDTensor>& p_ds, const std::array<const void*, NumDTensor>& p_ds,
void* p_e, void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths, const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides, const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -1140,14 +1295,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1140,14 +1295,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
p_b, p_b,
p_ds, p_ds,
p_e, p_e,
a_n_c_wis_lengths, a_g_n_c_wis_lengths,
a_n_c_wis_strides, a_g_n_c_wis_strides,
b_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_k_c_xs_strides, b_g_k_c_xs_strides,
ds_n_k_wos_lengths, ds_g_n_k_wos_lengths,
ds_n_k_wos_strides, ds_g_n_k_wos_strides,
e_n_k_wos_lengths, e_g_n_k_wos_lengths,
e_n_k_wos_strides, e_g_n_k_wos_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -1164,14 +1319,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1164,14 +1319,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
const void* p_b, const void* p_b,
const std::array<const void*, NumDTensor>& p_ds, const std::array<const void*, NumDTensor>& p_ds,
void* p_e, void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths, const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides, const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -1184,14 +1339,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1184,14 +1339,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
p_b, p_b,
p_ds, p_ds,
p_e, p_e,
a_n_c_wis_lengths, a_g_n_c_wis_lengths,
a_n_c_wis_strides, a_g_n_c_wis_strides,
b_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_k_c_xs_strides, b_g_k_c_xs_strides,
ds_n_k_wos_lengths, ds_g_n_k_wos_lengths,
ds_n_k_wos_strides, ds_g_n_k_wos_strides,
e_n_k_wos_lengths, e_g_n_k_wos_lengths,
e_n_k_wos_strides, e_g_n_k_wos_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -89,97 +89,33 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -89,97 +89,33 @@ struct ReferenceConvFwd : public device::BaseOperator
{ {
using Argument = ReferenceConvFwd::Argument; using Argument = ReferenceConvFwd::Argument;
// FIXME: properly implement "TensorView" for doing transpose or refer to dimension by name
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
// tensor descriptor in NCHW/KXYC/NKHW dimensional order // tensor descriptor in NCHW/KXYC/NKHW dimensional order
HostTensorDescriptor in_desc = arg.input_.mDesc;
HostTensorDescriptor wei_desc = arg.weight_.mDesc;
HostTensorDescriptor out_desc = arg.output_.mDesc;
// input
if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
arg.input_.mDesc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NHWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
arg.input_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
arg.input_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// weight
if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
arg.weight_.mDesc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
arg.weight_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
arg.weight_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// output
if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NWK>)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
arg.output_.mDesc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
arg.output_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
arg.output_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
if constexpr(NumDimSpatial == 1) if constexpr(NumDimSpatial == 1)
{ {
auto f_ncw = [&](auto n, auto k, auto wo) { auto func = [&](auto g, auto n, auto k, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{ {
for(std::size_t x = 0; x < wei_desc.GetLengths()[2]; ++x) for(std::size_t x = 0; x < arg.weight_.GetLengths()[3]; ++x)
{ {
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) + auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[2]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
// FIXME hacky
arg.in_element_op_( arg.in_element_op_(
v_in, v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
ck::type_convert<float>(
arg.input_
.mData[in_desc.GetOffsetFromMultiIndex(n, c, wi)]));
// FIXME hacky
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
ck::type_convert<float>(
arg.weight_
.mData[wei_desc.GetOffsetFromMultiIndex(k, c, x)]));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
...@@ -190,33 +126,32 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -190,33 +126,32 @@ struct ReferenceConvFwd : public device::BaseOperator
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
// FIXME hacky arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, wo})] =
ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(func,
out_desc.GetLengths()[0], arg.output_.GetLengths()[0],
out_desc.GetLengths()[1], arg.output_.GetLengths()[1],
out_desc.GetLengths()[2])( arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NumDimSpatial == 2)
{ {
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{ {
for(std::size_t y = 0; y < wei_desc.GetLengths()[2]; ++y) for(std::size_t y = 0; y < arg.weight_.GetLengths()[3]; ++y)
{ {
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) + auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < wei_desc.GetLengths()[3]; ++x) for(std::size_t x = 0; x < arg.weight_.GetLengths()[4]; ++x)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
...@@ -224,26 +159,18 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -224,26 +159,18 @@ struct ReferenceConvFwd : public device::BaseOperator
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < in_desc.GetLengths()[2] && ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
// FIXME hacky
arg.in_element_op_( arg.in_element_op_(
v_in, v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
ck::type_convert<float>(
arg.input_.mData[in_desc.GetOffsetFromMultiIndex(
n, c, hi, wi)]));
// FIXME hacky
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x)));
ck::type_convert<float>(
arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex(
k, c, y, x)]));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
...@@ -255,39 +182,38 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -255,39 +182,38 @@ struct ReferenceConvFwd : public device::BaseOperator
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
// FIXME hacky arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, ho, wo})] =
ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(func,
out_desc.GetLengths()[0], arg.output_.GetLengths()[0],
out_desc.GetLengths()[1], arg.output_.GetLengths()[1],
out_desc.GetLengths()[2], arg.output_.GetLengths()[2],
out_desc.GetLengths()[3])( arg.output_.GetLengths()[3],
arg.output_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NumDimSpatial == 3)
{ {
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) { auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{ {
for(std::size_t z = 0; z < wei_desc.GetLengths()[2]; ++z) for(std::size_t z = 0; z < arg.weight_.GetLengths()[3]; ++z)
{ {
auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) + auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t y = 0; y < wei_desc.GetLengths()[3]; ++y) for(std::size_t y = 0; y < arg.weight_.GetLengths()[4]; ++y)
{ {
auto hi = auto hi =
static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t x = 0; x < wei_desc.GetLengths()[4]; ++x) for(std::size_t x = 0; x < arg.weight_.GetLengths()[5]; ++x)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
...@@ -295,29 +221,24 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -295,29 +221,24 @@ struct ReferenceConvFwd : public device::BaseOperator
static_cast<ck::long_index_t>(arg.in_left_pads_[2]); static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
in_desc.GetLengths()[2] && arg.input_.GetLengths()[3] &&
hi >= 0 && hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) <
in_desc.GetLengths()[3] && arg.input_.GetLengths()[4] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[4]) ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
// FIXME hacky arg.in_element_op_(v_in,
arg.in_element_op_( ck::type_convert<float>(
v_in, arg.input_(g, n, c, di, hi, wi)));
ck::type_convert<float>(
arg.input_.mData[in_desc.GetOffsetFromMultiIndex(
n, c, di, hi, wi)]));
// FIXME hacky
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei,
ck::type_convert<float>( ck::type_convert<float>(arg.weight_(g, k, c, z, y, x)));
arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex(
k, c, z, y, x)]));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
...@@ -330,17 +251,16 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -330,17 +251,16 @@ struct ReferenceConvFwd : public device::BaseOperator
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
// FIXME hacky arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, d_o, ho, wo})] =
ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(func,
out_desc.GetLengths()[0], arg.output_.GetLengths()[0],
out_desc.GetLengths()[1], arg.output_.GetLengths()[1],
out_desc.GetLengths()[2], arg.output_.GetLengths()[2],
out_desc.GetLengths()[3], arg.output_.GetLengths()[3],
out_desc.GetLengths()[4])( arg.output_.GetLengths()[4],
arg.output_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
......
...@@ -18,6 +18,7 @@ struct ConvParam ...@@ -18,6 +18,7 @@ struct ConvParam
{ {
ConvParam(); ConvParam();
ConvParam(ck::index_t n_dim, ConvParam(ck::index_t n_dim,
ck::index_t group_count,
ck::index_t n_batch, ck::index_t n_batch,
ck::index_t n_out_channels, ck::index_t n_out_channels,
ck::index_t n_in_channels, ck::index_t n_in_channels,
...@@ -29,6 +30,7 @@ struct ConvParam ...@@ -29,6 +30,7 @@ struct ConvParam
const std::vector<ck::index_t>& right_pads); const std::vector<ck::index_t>& right_pads);
ck::index_t num_dim_spatial_; ck::index_t num_dim_spatial_;
ck::index_t G_;
ck::index_t N_; ck::index_t N_;
ck::index_t K_; ck::index_t K_;
ck::index_t C_; ck::index_t C_;
...@@ -50,20 +52,22 @@ struct ConvParam ...@@ -50,20 +52,22 @@ struct ConvParam
template <typename InDataType, typename WeiDataType, typename OutDataType> template <typename InDataType, typename WeiDataType, typename OutDataType>
std::size_t GetByte() const std::size_t GetByte() const
{ {
// sizeof(InDataType) * (N * C * <input spatial lengths product>) + // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
// sizeof(WeiDataType) * (K * C * <filter spatial lengths product>) + // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
// sizeof(OutDataType) * (N * K * <output spatial lengths product>); // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(InDataType) * (N_ * C_ * return sizeof(InDataType) *
std::accumulate(std::begin(input_spatial_lengths_), (G_ * N_ * C_ *
std::end(input_spatial_lengths_), std::accumulate(std::begin(input_spatial_lengths_),
static_cast<std::size_t>(1), std::begin(input_spatial_lengths_) + num_dim_spatial_,
std::multiplies<std::size_t>())) + static_cast<std::size_t>(1),
sizeof(WeiDataType) * (K_ * C_ * std::multiplies<std::size_t>())) +
std::accumulate(std::begin(filter_spatial_lengths_), sizeof(WeiDataType) *
std::end(filter_spatial_lengths_), (G_ * K_ * C_ *
static_cast<std::size_t>(1), std::accumulate(std::begin(filter_spatial_lengths_),
std::multiplies<std::size_t>())) + std::begin(filter_spatial_lengths_) + num_dim_spatial_,
sizeof(OutDataType) * (N_ * K_ * static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(OutDataType) * (G_ * N_ * K_ *
std::accumulate(std::begin(output_spatial_lengths_), std::accumulate(std::begin(output_spatial_lengths_),
std::end(output_spatial_lengths_), std::end(output_spatial_lengths_),
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
......
...@@ -256,6 +256,10 @@ struct Tensor ...@@ -256,6 +256,10 @@ struct Tensor
return *this; return *this;
} }
const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
const std::vector<std::size_t>& GetStrides() const { return mDesc.GetStrides(); }
void SetZero() void SetZero()
{ {
for(auto& v : mData) for(auto& v : mData)
......
...@@ -10,6 +10,7 @@ namespace utils { ...@@ -10,6 +10,7 @@ namespace utils {
namespace conv { namespace conv {
ConvParam::ConvParam(ck::index_t n_dim, ConvParam::ConvParam(ck::index_t n_dim,
ck::index_t group_count,
ck::index_t n_batch, ck::index_t n_batch,
ck::index_t n_out_channels, ck::index_t n_out_channels,
ck::index_t n_in_channels, ck::index_t n_in_channels,
...@@ -20,6 +21,7 @@ ConvParam::ConvParam(ck::index_t n_dim, ...@@ -20,6 +21,7 @@ ConvParam::ConvParam(ck::index_t n_dim,
const std::vector<ck::index_t>& left_pads, const std::vector<ck::index_t>& left_pads,
const std::vector<ck::index_t>& right_pads) const std::vector<ck::index_t>& right_pads)
: num_dim_spatial_(n_dim), : num_dim_spatial_(n_dim),
G_(group_count),
N_(n_batch), N_(n_batch),
K_(n_out_channels), K_(n_out_channels),
C_(n_in_channels), C_(n_in_channels),
...@@ -57,7 +59,7 @@ ConvParam::ConvParam(ck::index_t n_dim, ...@@ -57,7 +59,7 @@ ConvParam::ConvParam(ck::index_t n_dim,
} }
ConvParam::ConvParam() ConvParam::ConvParam()
: ConvParam::ConvParam(2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}) : ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
{ {
} }
...@@ -68,14 +70,14 @@ std::vector<ck::index_t> ConvParam::GetOutputSpatialLengths() const ...@@ -68,14 +70,14 @@ std::vector<ck::index_t> ConvParam::GetOutputSpatialLengths() const
std::size_t ConvParam::GetFlops() const std::size_t ConvParam::GetFlops() const
{ {
// 2 * N * K * C * <output spatial lengths product> * <filter spatial lengths product> // 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return static_cast<std::size_t>(2) * N_ * K_ * C_ * return static_cast<std::size_t>(2) * G_ * N_ * K_ * C_ *
std::accumulate(std::begin(output_spatial_lengths_), std::accumulate(std::begin(output_spatial_lengths_),
std::end(output_spatial_lengths_), std::begin(output_spatial_lengths_) + num_dim_spatial_,
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
std::multiplies<std::size_t>()) * std::multiplies<std::size_t>()) *
std::accumulate(std::begin(filter_spatial_lengths_), std::accumulate(std::begin(filter_spatial_lengths_),
std::end(filter_spatial_lengths_), std::begin(filter_spatial_lengths_) + num_dim_spatial_,
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
std::multiplies<std::size_t>()); std::multiplies<std::size_t>());
} }
...@@ -87,13 +89,14 @@ std::size_t ConvParam::GetFlops() const ...@@ -87,13 +89,14 @@ std::size_t ConvParam::GetFlops() const
std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p) std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p)
{ {
os << "ConvParam {" os << "ConvParam {"
<< "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nN: " << p.N_ << "\nK: " << p.K_ << "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nG: " << p.G_ << "\nN: " << p.N_
<< "\nC: " << p.C_ << "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_ << "\nK: " << p.K_ << "\nC: " << p.C_
<< "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_
<< "\ninput_spatial_lengths: " << p.input_spatial_lengths_ << "\ninput_spatial_lengths: " << p.input_spatial_lengths_
<< "\nconv_filter_strides: " << p.conv_filter_strides_ << "\nconv_filter_strides: " << p.conv_filter_strides_
<< "\nconv_filter_dilations: " << p.conv_filter_dilations_ << "\nconv_filter_dilations: " << p.conv_filter_dilations_
<< "\ninput_left_pads: " << p.input_left_pads_ << "\ninput_left_pads: " << p.input_left_pads_
<< "\ninput_right_pads: " << p.input_right_pads_; << "\ninput_right_pads: " << p.input_right_pads_ << "}\n";
return os; return os;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment