Commit a5011336 authored by Chao Liu
Browse files

refactor

parent 33975236
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <vector>
#include "ck/tensor_description/tensor_descriptor.hpp"
// Streams the dimension lengths of a ck::TensorDescriptor to `os` in the form
// "{l0, l1, ..., lN-1}", where N is the descriptor's GetNumOfDimension().
//
// os   : destination stream; returned by reference so calls can be chained.
// desc : descriptor whose per-dimension lengths (desc.GetLength(i)) are printed.
//
// A descriptor that reports zero dimensions prints "{}". The separator is written
// *before* every length except the first, so the loop bound is nDim itself and no
// `nDim - 1` index (which would be negative for an empty descriptor) is ever formed.
template <typename... Ts>
std::ostream& operator<<(std::ostream& os, const ck::TensorDescriptor<Ts...>& desc)
{
    constexpr ck::index_t nDim = ck::remove_cvref_t<decltype(desc)>::GetNumOfDimension();

    os << "{";

    // NOTE(review): relies on static_for handing the functor a ck::Number<I>-like
    // compile-time index exposing a static `value` member -- confirm against ck headers.
    ck::static_for<0, nDim, 1>{}([&](auto i) {
        os << (decltype(i)::value == 0 ? "" : ", ") << desc.GetLength(i);
    });

    os << "}";

    return os;
}
......@@ -20,6 +20,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/device_utility/io.hpp"
namespace ck {
namespace tensor_operation {
......@@ -84,12 +85,6 @@ __global__ void
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
// input : input image A[N, C, Hi, Wi],
// input : weight B[K, C, Y, X],
// input : D0[N, K, Ho, Wo], D1[N, K, Ho, Wo], ...
// output : output image E[N, K, Ho, Wo]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
......@@ -172,8 +167,6 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
BElementwiseOperation,
CDEElementwiseOperation>
{
namespace ctc = ck::tensor_layout::convolution;
using DeviceOp = DeviceConvFwdMultipleD_Xdl_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
......@@ -189,7 +182,9 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay, typename std::enable_if<is_same_v<ALay, ctc::NWC>, bool>::type = false>
template <typename ALay,
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
......@@ -299,7 +294,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
template <typename ALay,
typename std::enable_if<is_same_v<ALay, ctc::NHWC>, bool>::type = false>
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NHWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
......@@ -423,7 +419,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
template <typename ALay,
typename std::enable_if<is_same_v<ALay, ctc::NDHWC>, bool>::type = false>
typename std::enable_if<is_same_v<ALay, tensor_layout::convolution::NDHWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
......@@ -570,11 +567,24 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// KYXC, K_YXC
// KZYXC, K_ZYXC
template <typename BLay,
typename std::enable_if<is_same_v<BLay, ctc::KXC> || is_same_v<BLay, ctc::KYXC> ||
is_same_v<BLay, ctc::KZYXC>,
typename std::enable_if<is_same_v<BLay, tensor_layout::convolution::KXC> ||
is_same_v<BLay, tensor_layout::convolution::KYXC> ||
is_same_v<BLay, tensor_layout::convolution::KZYXC>,
bool>::type = false>
static auto MakeBGridDescriptor_N_K(index_t GemmNRaw, index_t GemmKRaw)
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides)
{
const index_t K = b_k_c_xs_lengths[0];
const index_t C = b_k_c_xs_lengths[1];
const index_t GemmNRaw = K;
const index_t GemmKRaw = C * std::accumulate(b_k_c_xs_lengths.begin() + 2,
b_k_c_xs_lengths.begin() + 2 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmNRaw, GemmKRaw));
......@@ -585,37 +595,16 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
template <typename ELay,
typename std::enable_if<is_same_v<ELay, ctc::NWK> || is_same_v<ELay, ctc::NHWK> ||
is_same_v<ELay, ctc::NDHWK>,
typename std::enable_if<is_same_v<ELay, tensor_layout::convolution::NWK> ||
is_same_v<ELay, tensor_layout::convolution::NHWK> ||
is_same_v<ELay, tensor_layout::convolution::NDHWK>,
bool>::type = false>
static auto MakeEGridDescriptor_M_N(index_t GemmMRaw, index_t GemmN)
{
const index_t GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
const auto out_gemmmraw_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmM, GemmN));
const auto out_gemmm_gemmn_grid_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmn_grid_desc);
return out_gemmm_gemmn_grid_desc;
}
static auto
MakeABEGridDescriptors(const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides)
{
const index_t N = a_n_c_wis_lengths[0];
const index_t K = b_k_c_xs_lengths[0];
const index_t C = a_n_c_wis_lengths[1];
const index_t N = e_n_k_wos_lengths[0];
const index_t K = e_n_k_wos_lengths[1];
const index_t GemmMRaw = N * std::accumulate(e_n_k_wos_lengths.begin() + 2,
e_n_k_wos_lengths.begin() + 2 + NDimSpatial,
......@@ -624,42 +613,22 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
const index_t GemmNRaw = K;
const index_t GemmKRaw = C * std::accumulate(b_k_c_xs_lengths.begin() + 2,
b_k_c_xs_lengths.begin() + 2 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
// A:
const auto in_gemmm_gemmk_grid_desc =
MakeAGridDescriptor_M_K<ALayout>(a_n_c_wis_lengths,
a_n_c_wis_strides,
b_k_c_xs_lengths,
b_k_c_xs_strides,
e_n_k_wos_lengths,
e_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// B:
const auto wei_gemmn_gemmk_grid_desc = MakeBGridDescriptor_N_K<BLayout>(GemmNRaw, GemmKRaw);
const auto out_gemmmraw_gemmnraw_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(GemmMRaw, GemmNRaw));
// E:
const auto out_gemmm_gemmn_grid_desc = MakeEGridDescriptor_M_N<ELayout>(GemmMRaw, GemmNRaw);
const auto out_gemmm_gemmn_grid_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_grid_desc);
return make_tuple(
in_gemmm_gemmk_grid_desc, wei_gemmn_gemmk_grid_desc, out_gemmm_gemmn_grid_desc);
return out_gemmm_gemmn_grid_desc;
}
using ABEGridDescs = decltype(MakeABEGridDescriptors({}, {}, {}, {}, {}, {}, {}, {}, {}, {}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABEGridDescs{}[I0])>;
using BGridDesc_N_K = remove_cvref_t<decltype(ABEGridDescs{}[I1])>;
using EGridDesc_M_N = remove_cvref_t<decltype(ABEGridDescs{}[I2])>;
using AGridDesc_M_K = remove_cvref_t<decltype(
MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
......@@ -739,9 +708,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_ds_grid_{}, // FIXME
p_e_grid_{static_cast<EDataType*>(p_e)},
a_grid_desc_m_k_{},
b_grid_desc_n_k_{},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{},
a_grid_desc_ak0_m_ak1_{},
b_grid_desc_bk0_n_bk1_{},
e_grid_desc_m_n_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{},
a_element_op_{a_element_op},
......@@ -760,7 +733,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
const auto descs = DeviceOp::MakeABEGridDescriptors(a_n_c_wis_lengths,
a_grid_desc_m_k_ = DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_n_c_wis_lengths,
a_n_c_wis_strides,
b_k_c_xs_lengths,
b_k_c_xs_strides,
......@@ -771,21 +744,22 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
input_left_pads,
input_right_pads);
const auto a_grid_desc_m_k = descs[I0];
const auto b_grid_desc_n_k = descs[I1];
e_grid_desc_m_n_ = descs[I2];
b_grid_desc_n_k_ =
DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_k_c_xs_lengths, b_k_c_xs_strides);
e_grid_desc_m_n_ =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_n_k_wos_lengths, e_n_k_wos_strides);
a_grid_desc_ak0_m_ak1_ =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_);
b_grid_desc_bk0_n_bk1_ =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_);
block_2_etile_map_ = Block2ETileMap{e_grid_desc_m_n_};
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
e_grid_desc_m_n_,
block_2_etile_map_))
if(GridwiseGemm::CheckValidity(
a_grid_desc_m_k_, b_grid_desc_n_k_, e_grid_desc_m_n_, block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -801,14 +775,19 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
EDataType* p_e_grid_;
// tensor descriptors
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
// FIXME: don't assume D and E desc are the same type
StaticallyIndexedArray<EGridDesc_M_N, NumDTensor> ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
// FIXME: don't assume D and E desc are the same type
StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N e_grid_desc_m_n_;
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
......@@ -844,27 +823,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
{
#if 1
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "A[M, K]: " << arg.a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << arg.b_grid_desc_n_k_ << std::endl;
std::cout << "E[M, N]: " << arg.e_grid_desc_m_n_ << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle has invalid setting");
"wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
}
const index_t grid_size =
......@@ -931,6 +901,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static bool IsSupportedArgument(const Argument& arg)
{
namespace ctc = tensor_layout::convolution;
// check device
if(get_device_name() == "gfx908")
{
......@@ -1049,8 +1021,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
......
......@@ -70,7 +70,7 @@ template <typename FloatAB,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
struct GridwiseGemmMultipleD_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
......@@ -222,10 +222,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename AGridDesc_AK0_M_AK1, typename BGridDesc_BK0_N_BK1, typename Block2ETileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
template <typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
......@@ -233,9 +232,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
return false;
......@@ -271,7 +270,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment