Commit 12585e57 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 90acba1d
......@@ -142,22 +142,105 @@ int run_conv_fwd(bool do_verification,
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
// tensor descriptor in NCHW/KXYC/NKHW dimensional order
HostTensorDescriptor in_n_c_wis_desc = in_desc;
HostTensorDescriptor wei_k_c_xs_desc = wei_desc;
HostTensorDescriptor out_n_k_wos_desc = out_desc;
// input
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC>)
{
in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC>)
{
in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// weight
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
{
wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC>)
{
wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// output
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK>)
{
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>)
{
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
copy(in_n_c_wis_desc.GetLengths(), a_n_c_wis_lengths);
copy(in_n_c_wis_desc.GetStrides(), a_n_c_wis_strides);
copy(wei_k_c_xs_desc.GetLengths(), b_k_c_xs_lengths);
copy(wei_k_c_xs_desc.GetStrides(), b_k_c_xs_strides);
copy(out_n_k_wos_desc.GetLengths(), e_n_k_wos_lengths);
copy(out_n_k_wos_desc.GetStrides(), e_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
// do GEMM
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.N_,
conv_param.K_,
conv_param.C_,
conv_param.input_spatial_lengths_,
conv_param.filter_spatial_lengths_,
conv_param.GetOutputSpatialLengths(),
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
out_device_buf.GetDeviceBuffer(),
a_n_c_wis_lengths,
a_n_c_wis_strides,
b_k_c_xs_lengths,
b_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 2>, 0>{{}},
std::array<std::array<ck::index_t, NDimSpatial + 2>, 0>{{}},
e_n_k_wos_lengths,
e_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
......
......@@ -20,7 +20,7 @@ namespace device {
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <ck::index_t NDimSpatial,
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
......@@ -36,23 +36,26 @@ struct DeviceConvFwdMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const ADataType* p_a,
const BDataType* p_b,
EDataType* p_e,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -181,59 +181,38 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <
typename BLayout_,
typename std::enable_if<is_same_v<BLayout_, ck::tensor_layout::convolution::KXC> ||
is_same_v<BLayout_, ck::tensor_layout::convolution::KYXC> ||
is_same_v<BLayout_, ck::tensor_layout::convolution::KZYXC>,
bool>::type = false>
static auto MakeBGridDescriptor_N_K(index_t GemmNRaw, index_t GemmKRaw)
template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmNRaw, GemmKRaw));
const auto wei_gemmn_gemmk_grid_desc =
matrix_padder.PadBDescriptor_N_K(wei_k_yxc_grid_desc);
return wei_gemmn_gemmk_grid_desc;
}
const index_t N = a_n_c_wis_lengths[0];
const index_t C = a_n_c_wis_lengths[1];
template <
typename ELayout_,
typename std::enable_if<is_same_v<ELayout_, ck::tensor_layout::convolution::NWK> ||
is_same_v<ELayout_, ck::tensor_layout::convolution::NHWK> ||
is_same_v<ELayout_, ck::tensor_layout::convolution::NDHWK>,
bool>::type = false>
static auto MakeEGridDescriptor_M_N(index_t GemmMRaw, index_t GemmN)
{
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
const index_t GemmMRaw = N * std::accumulate(e_n_k_wos_lengths.begin() + 2,
e_n_k_wos_lengths.begin() + 3,
index_t{1},
std::multiplies<index_t>());
const auto out_gemmmraw_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmM, GemmN));
const index_t GemmKRaw = C * std::accumulate(b_k_c_xs_lengths.begin() + 2,
b_k_c_xs_lengths.begin() + 3,
index_t{1},
std::multiplies<index_t>());
const auto out_gemmm_gemmn_grid_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmn_grid_desc);
const index_t Wi = a_n_c_wis_lengths[2];
return out_gemmm_gemmn_grid_desc;
}
const index_t Wo = e_n_k_wos_lengths[2];
template <typename ALayout_,
typename std::enable_if<is_same_v<ALayout_, ck::tensor_layout::convolution::NWC>,
bool>::type = false>
static auto MakeAGridDescriptor_M_K(index_t N,
index_t C,
index_t GemmMRaw,
index_t GemmKRaw,
const std::vector<index_t>& input_spatial_lengths,
const std::vector<index_t>& filter_spatial_lengths,
const std::vector<index_t>& output_spatial_lengths,
const std::vector<index_t>& conv_filter_strides,
const std::vector<index_t>& conv_filter_dilations,
const std::vector<index_t>& input_left_pads,
const std::vector<index_t>& input_right_pads)
{
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
......@@ -274,7 +253,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
else
{
const index_t X = filter_spatial_lengths[0];
const index_t X = b_k_c_xs_lengths[2];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
......@@ -313,26 +292,39 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
template <typename ALayout_,
typename std::enable_if<is_same_v<ALayout_, ck::tensor_layout::convolution::NHWC>,
template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NHWC>,
bool>::type = false>
static auto MakeAGridDescriptor_M_K(index_t N,
index_t C,
index_t GemmMRaw,
index_t GemmKRaw,
const std::vector<index_t>& input_spatial_lengths,
const std::vector<index_t>& filter_spatial_lengths,
const std::vector<index_t>& output_spatial_lengths,
const std::vector<index_t>& conv_filter_strides,
const std::vector<index_t>& conv_filter_dilations,
const std::vector<index_t>& input_left_pads,
const std::vector<index_t>& input_right_pads)
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t N = a_n_c_wis_lengths[0];
const index_t C = a_n_c_wis_lengths[1];
const index_t GemmMRaw = N * std::accumulate(e_n_k_wos_lengths.begin() + 2,
e_n_k_wos_lengths.begin() + 4,
index_t{1},
std::multiplies<index_t>());
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t GemmKRaw = C * std::accumulate(b_k_c_xs_lengths.begin() + 2,
b_k_c_xs_lengths.begin() + 4,
index_t{1},
std::multiplies<index_t>());
const index_t Hi = a_n_c_wis_lengths[2];
const index_t Wi = a_n_c_wis_lengths[3];
const index_t Ho = e_n_k_wos_lengths[2];
const index_t Wo = e_n_k_wos_lengths[3];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
......@@ -377,8 +369,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
else
{
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1];
const index_t Y = b_k_c_xs_lengths[2];
const index_t X = b_k_c_xs_lengths[3];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
......@@ -425,28 +417,41 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
template <typename ALayout_,
typename std::enable_if<is_same_v<ALayout_, ck::tensor_layout::convolution::NDHWC>,
template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NDHWC>,
bool>::type = false>
static auto MakeAGridDescriptor_M_K(index_t N,
index_t C,
index_t GemmMRaw,
index_t GemmKRaw,
const std::vector<index_t>& input_spatial_lengths,
const std::vector<index_t>& filter_spatial_lengths,
const std::vector<index_t>& output_spatial_lengths,
const std::vector<index_t>& conv_filter_strides,
const std::vector<index_t>& conv_filter_dilations,
const std::vector<index_t>& input_left_pads,
const std::vector<index_t>& input_right_pads)
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t N = a_n_c_wis_lengths[0];
const index_t C = a_n_c_wis_lengths[1];
const index_t GemmMRaw = N * std::accumulate(e_n_k_wos_lengths.begin() + 2,
e_n_k_wos_lengths.begin() + 5,
index_t{1},
std::multiplies<index_t>());
const index_t GemmKRaw = C * std::accumulate(b_k_c_xs_lengths.begin() + 2,
b_k_c_xs_lengths.begin() + 5,
index_t{1},
std::multiplies<index_t>());
const index_t Di = a_n_c_wis_lengths[2];
const index_t Hi = a_n_c_wis_lengths[3];
const index_t Wi = a_n_c_wis_lengths[4];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t Do = e_n_k_wos_lengths[2];
const index_t Ho = e_n_k_wos_lengths[3];
const index_t Wo = e_n_k_wos_lengths[4];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
......@@ -495,9 +500,9 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
else
{
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t Z = b_k_c_xs_lengths[2];
const index_t Y = b_k_c_xs_lengths[3];
const index_t X = b_k_c_xs_lengths[4];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
......@@ -556,49 +561,80 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
static index_t GetGemmMRaw(index_t N, const std::vector<index_t>& output_spatial_lengths)
// supported layout:
// KXC, K_XC
// KYXC, K_YXC
// KZYXC, K_ZYXC
template <typename BLay,
typename std::enable_if<is_same_v<BLay, tensor_layout::convolution::KXC> ||
is_same_v<BLay, tensor_layout::convolution::KYXC> ||
is_same_v<BLay, tensor_layout::convolution::KZYXC>,
bool>::type = false>
static auto MakeBGridDescriptor_N_K(index_t GemmNRaw, index_t GemmKRaw)
{
return N * std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
1,
std::multiplies<index_t>());
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmNRaw, GemmKRaw));
const auto wei_gemmn_gemmk_grid_desc =
matrix_padder.PadBDescriptor_N_K(wei_k_yxc_grid_desc);
return wei_gemmn_gemmk_grid_desc;
}
static index_t GetGemmKRaw(index_t C, const std::vector<index_t>& filter_spatial_lengths)
template <typename ELay,
typename std::enable_if<is_same_v<ELay, tensor_layout::convolution::NWK> ||
is_same_v<ELay, tensor_layout::convolution::NHWK> ||
is_same_v<ELay, tensor_layout::convolution::NDHWK>,
bool>::type = false>
static auto MakeEGridDescriptor_M_N(index_t GemmMRaw, index_t GemmN)
{
return C * std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
1,
std::multiplies<index_t>());
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
const auto out_gemmmraw_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmM, GemmN));
const auto out_gemmm_gemmn_grid_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmn_grid_desc);
return out_gemmm_gemmn_grid_desc;
}
static auto
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads)
MakeABEGridDescriptors(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
using namespace ck;
const index_t N = a_n_c_wis_lengths[0];
const index_t K = b_k_c_xs_lengths[0];
const index_t C = a_n_c_wis_lengths[1];
const index_t GemmMRaw = N * std::accumulate(e_n_k_wos_lengths.begin() + 2,
e_n_k_wos_lengths.begin() + 2 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const index_t GemmMRaw = GetGemmMRaw(N, output_spatial_lengths);
const index_t GemmNRaw = K;
const index_t GemmKRaw = GetGemmKRaw(C, filter_spatial_lengths);
const index_t GemmKRaw = C * std::accumulate(b_k_c_xs_lengths.begin() + 2,
b_k_c_xs_lengths.begin() + 2 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
// A:
const auto in_gemmm_gemmk_grid_desc =
MakeAGridDescriptor_M_K<ALayout>(N,
C,
GemmMRaw,
GemmKRaw,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
MakeAGridDescriptor_M_K<ALayout>(a_n_c_wis_lengths,
a_n_c_wis_strides,
b_k_c_xs_lengths,
b_k_c_xs_strides,
e_n_k_wos_lengths,
e_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......@@ -614,28 +650,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
in_gemmm_gemmk_grid_desc, wei_gemmn_gemmk_grid_desc, out_gemmm_gemmn_grid_desc);
}
template <index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetABEGridDesc()
{
return MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
}
template <index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetABEGridDesc()
{
return MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
}
template <index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetABEGridDesc()
{
return MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}
using ABEGridDescs = decltype(GetABEGridDesc<NDimSpatial>());
using ABEGridDescs = decltype(MakeABEGridDescriptors({}, {}, {}, {}, {}, {}, {}, {}, {}, {}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABEGridDescs{}[I0])>;
using BGridDesc_N_K = remove_cvref_t<decltype(ABEGridDescs{}[I1])>;
......@@ -698,53 +713,61 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_in_grid,
const BDataType* p_wei_grid,
EDataType* p_out_grid,
index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads,
AElementwiseOperation in_element_op,
BElementwiseOperation wei_element_op,
CDEElementwiseOperation out_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_in_grid)},
p_b_grid_{static_cast<const BDataType*>(p_wei_grid)},
Argument(
const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a)},
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_ds_grid_{}, // FIXME
p_e_grid_{static_cast<EDataType*>(p_out_grid)},
p_e_grid_{static_cast<EDataType*>(p_e)},
a_grid_desc_ak0_m_ak1_{},
b_grid_desc_bk0_n_bk1_{},
e_grid_desc_m_n_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{},
a_element_op_{in_element_op},
b_element_op_{wei_element_op},
cde_element_op_{out_element_op},
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
filter_spatial_lengths_{filter_spatial_lengths},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_n_c_wis_lengths_{a_n_c_wis_lengths},
a_n_c_wis_strides_{a_n_c_wis_strides},
b_k_c_xs_lengths_{b_k_c_xs_lengths},
b_k_c_xs_strides_{b_k_c_xs_strides},
ds_n_k_wos_lengths_{ds_n_k_wos_lengths},
ds_n_k_wos_strides_{ds_n_k_wos_strides},
e_n_k_wos_lengths_{e_n_k_wos_lengths},
e_n_k_wos_strides_{e_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
const auto descs =
DeviceOp::MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
const auto descs = DeviceOp::MakeABEGridDescriptors(a_n_c_wis_lengths,
a_n_c_wis_strides,
b_k_c_xs_lengths,
b_k_c_xs_strides,
e_n_k_wos_lengths,
e_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
const auto a_grid_desc_m_k = descs[I0];
const auto b_grid_desc_n_k = descs[I1];
......@@ -796,13 +819,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument()
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
std::array<index_t, NDimSpatial + 2> a_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 2> a_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 2> b_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 2> b_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor> ds_n_k_wos_lengths_;
std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor> ds_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 2> e_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 2> e_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
};
// Invoker
......@@ -856,7 +884,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray<
StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -905,21 +933,9 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static bool IsSupportedArgument(const Argument& arg)
{
#if 1
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{" << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0)
<< ", " << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{" << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0)
<< ", " << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
namespace ctc = tensor_layout::convolution;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(ck::get_device_name() == "gfx908")
if(get_device_name() == "gfx908")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t>))
......@@ -927,7 +943,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false;
}
}
else if(ck::get_device_name() == "gfx90a")
else if(get_device_name() == "gfx90a")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
......@@ -940,8 +956,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false;
}
// tensors can't be bigger than 2GB each.
constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31);
// check tensor size: can't be larger than 2GB each
constexpr long_index_t GB2 = (long_index_t{1} << 31);
if(arg.a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 ||
arg.b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 ||
......@@ -950,14 +966,19 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false;
}
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
const index_t X = arg.b_k_c_xs_lengths_[i + 2];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
......@@ -969,24 +990,63 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.input_left_pads_[i] == 0 &&
arg.input_right_pads_[i] == 0))
const index_t X = arg.b_k_c_xs_lengths_[i + 2];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
// vector load A/B matrix from global memory
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
// check vector access of A
if constexpr(is_same_v<ALayout, ctc::NWC> || is_same_v<ALayout, ctc::NHWC> ||
is_same_v<ALayout, ctc::NDHWC>)
{
const index_t C = arg.a_n_c_wis_lengths_[1];
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// vector store D/E matrix into global memory
if(!(arg.Conv_K_ % CDEBlockTransferScalarPerVector_NPerBlock == 0))
// check vector access of B
if constexpr(is_same_v<BLayout, ctc::KXC> || is_same_v<BLayout, ctc::KYXC> ||
is_same_v<BLayout, ctc::KZYXC>)
{
const index_t C = arg.b_k_c_xs_lengths_[1];
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// FIXME: check vector access of Ds
// check vector access of E
if constexpr(is_same_v<ELayout, ctc::NWK> || is_same_v<ELayout, ctc::NHWK> ||
is_same_v<ELayout, ctc::NDHWK>)
{
const index_t K = arg.e_n_k_wos_lengths_[1];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
return false;
}
......@@ -1003,77 +1063,90 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_in_grid,
const BDataType* p_wei_grid,
EDataType* p_out_grid,
index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads,
AElementwiseOperation in_element_op,
BElementwiseOperation wei_element_op,
CDEElementwiseOperation out_element_op)
static auto MakeArgument(
const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
return Argument{p_in_grid,
p_wei_grid,
p_out_grid,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
return Argument{p_a,
p_b,
p_ds,
p_e,
a_n_c_wis_lengths,
a_n_c_wis_strides,
b_k_c_xs_lengths,
b_k_c_xs_strides,
ds_n_k_wos_lengths,
ds_n_k_wos_strides,
e_n_k_wos_lengths,
e_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const ADataType* p_in_grid,
const BDataType* p_wei_grid,
EDataType* p_out_grid,
index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads,
AElementwiseOperation in_element_op,
BElementwiseOperation wei_element_op,
CDEElementwiseOperation out_element_op) override
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_in_grid),
static_cast<const BDataType*>(p_wei_grid),
static_cast<EDataType*>(p_out_grid),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_n_c_wis_lengths,
a_n_c_wis_strides,
b_k_c_xs_lengths,
b_k_c_xs_strides,
ds_n_k_wos_lengths,
ds_n_k_wos_strides,
e_n_k_wos_lengths,
e_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment