Unverified Commit 27858374 authored by Shaojie WANG's avatar Shaojie WANG Committed by GitHub
Browse files

Conv bwd data multiple d (#404)



* init commit of convnd bwd data

* begin compiling example

* have a first version that produce a right result

* refine device level launch kernel code

* add more instances in example and get right results

* clang-format

* format example file

* add more instances

* fix instances

* adding conv_bwd_data multile_d

* adding conv_bwd_data multile_d

* adding conv_bwd multiple d

* adding conv_bwd multiple d

* adding conv_bwd multiple d

* refactor

* refactor

* adding conv bwd data multiple d

* adding conv bwd data multiple d

* adding conv bwd data multiple d

* adding conv bwd data multiple d

* adding conv bwd data multiple d

* adding conv bwd data multiple d

* adding conv bwd data multiple d

* refactor

* update conv fwd's bias impl

* refactor

* reorg file

* clean up cmake

* clean

* clean

* clean
Co-authored-by: default avatarChao Liu <lc.roy86@gmail.com>
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent 43c898f6
...@@ -117,7 +117,7 @@ __global__ void ...@@ -117,7 +117,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batch_gemm_multiple_d_xdl_cshuffle( kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
...@@ -136,8 +136,7 @@ __global__ void ...@@ -136,8 +136,7 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
// offset base pointer for each work-group
#if 1
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...@@ -174,24 +173,6 @@ __global__ void ...@@ -174,24 +173,6 @@ __global__ void
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_, e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map); block_2_ctile_map);
#else
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map);
#endif
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
...@@ -378,6 +359,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -378,6 +359,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
// desc for problem definition
using AGridDesc_M_K = remove_cvref_t<decltype( using AGridDesc_M_K = remove_cvref_t<decltype(
MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>; using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
...@@ -395,10 +377,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -395,10 +377,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -432,12 +410,19 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -432,12 +410,19 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; // block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -467,6 +452,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -467,6 +452,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
p_b_grid_{static_cast<const BDataType*>(p_b)}, p_b_grid_{static_cast<const BDataType*>(p_b)},
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)}, p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths, a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
...@@ -561,6 +547,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -561,6 +547,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
EDataType* p_e_grid_; EDataType* p_e_grid_;
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_;
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
...@@ -569,14 +556,14 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -569,14 +556,14 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map // block-to-e-tile map
Block2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
// element-wise op // element-wise op
...@@ -622,8 +609,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -622,8 +609,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
} }
const index_t grid_size = const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
arg.a_g_n_c_wis_lengths_[0]; // Group count
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
...@@ -631,7 +617,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -631,7 +617,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_batch_gemm_multiple_d_xdl_cshuffle< const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer, typename GridwiseGemm::DsGridPointer,
...@@ -641,8 +627,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -641,8 +627,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap, Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumDTensor>, ComputePtrOffsetOfStridedBatch<NumDTensor>,
has_main_loop>; has_main_loop>;
...@@ -798,7 +784,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -798,7 +784,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> || is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> || is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> || is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
is_same_v<DLayout, ctc::NDHWGK>) is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::GK> ||
is_same_v<DLayout, ctc::G_K>)
{ {
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
......
...@@ -238,10 +238,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -238,10 +238,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumPrefetch, // NumGemmKPrefetchStage NumPrefetch, // NumGemmKPrefetchStage
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -275,19 +271,19 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -275,19 +271,19 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
struct GroupedGemmBlock2ETileMap struct GroupedGemmBlock2ETileMap
{ {
using UnderlyingBlock2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
static_assert(
std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})),
typename GridwiseGemm::DefaultBlock2ETileMap>::value,
"Wrong! Should be the same type name");
GroupedGemmBlock2ETileMap() GroupedGemmBlock2ETileMap()
{ {
...@@ -321,7 +317,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -321,7 +317,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
return block_2_etile_map_.CheckValidity(e_grid_desc_m_n); return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
} }
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
ck::index_t BlockStart_; ck::index_t BlockStart_;
}; };
...@@ -342,10 +338,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -342,10 +338,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map // block-to-e-tile map
GroupedGemmBlock2ETileMap block_2_etile_map_; GroupedGemmBlock2ETileMap block_2_etile_map_;
...@@ -440,7 +435,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -440,7 +435,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
block_2_etile_map)) block_2_etile_map))
{ {
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock; ds_grid_desc_mblock_mperblock_nblock_nperblock;
static_for<0, NumDTensor, 1>{}([&](auto j) { static_for<0, NumDTensor, 1>{}([&](auto j) {
......
...@@ -92,6 +92,12 @@ struct GNDHWC : public BaseTensorLayout ...@@ -92,6 +92,12 @@ struct GNDHWC : public BaseTensorLayout
static constexpr const char* name = "GNDHWC"; static constexpr const char* name = "GNDHWC";
}; };
// for input bias
struct GC : public BaseTensorLayout
{
static constexpr const char* name = "GC";
};
// input tensor // input tensor
// packed NWGC/NHWGC/NDHWGC // packed NWGC/NHWGC/NDHWGC
struct NWGC : public BaseTensorLayout struct NWGC : public BaseTensorLayout
...@@ -126,6 +132,12 @@ struct G_NDHW_C : public BaseTensorLayout ...@@ -126,6 +132,12 @@ struct G_NDHW_C : public BaseTensorLayout
static constexpr const char* name = "G_NDHW_C"; static constexpr const char* name = "G_NDHW_C";
}; };
// for input bias
struct G_C : public BaseTensorLayout
{
static constexpr const char* name = "G_C";
};
// weight tensor // weight tensor
// packed KCX/KCYX/KCZYX // packed KCX/KCYX/KCZYX
struct KCX : public BaseTensorLayout struct KCX : public BaseTensorLayout
...@@ -296,6 +308,12 @@ struct GNDHWK : public BaseTensorLayout ...@@ -296,6 +308,12 @@ struct GNDHWK : public BaseTensorLayout
static constexpr const char* name = "GNDHWK"; static constexpr const char* name = "GNDHWK";
}; };
// for output bias
struct GK : public BaseTensorLayout
{
static constexpr const char* name = "GK";
};
// output tensor // output tensor
// packed NWGK/NHWGK/NDHWGK // packed NWGK/NHWGK/NDHWGK
struct NWGK : public BaseTensorLayout struct NWGK : public BaseTensorLayout
...@@ -330,6 +348,12 @@ struct G_NDHW_K : public BaseTensorLayout ...@@ -330,6 +348,12 @@ struct G_NDHW_K : public BaseTensorLayout
static constexpr const char* name = "G_NDHW_K"; static constexpr const char* name = "G_NDHW_K";
}; };
// for output bias
struct G_K : public BaseTensorLayout
{
static constexpr const char* name = "G_K";
};
// K-reduced output tensor (packed) // K-reduced output tensor (packed)
struct GNW : public BaseTensorLayout struct GNW : public BaseTensorLayout
{ {
......
...@@ -35,10 +35,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype ...@@ -35,10 +35,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation, InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -166,6 +162,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -166,6 +162,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
} }
// A desc for source in blockwise copy // A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{ {
...@@ -182,6 +179,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -182,6 +179,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
} }
// B desc for source in blockwise copy // B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{ {
...@@ -198,9 +196,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -198,9 +196,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
} }
// E desc for destination in blockwise copy // E desc for destination in blockwise copy
template <typename EGridDescriptor_M_N> template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( __host__ __device__ static constexpr auto
const EGridDescriptor_M_N& e_grid_desc_m_n) MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{ {
const auto M = e_grid_desc_m_n.GetLength(I0); const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1); const auto N = e_grid_desc_m_n.GetLength(I1);
...@@ -219,10 +217,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -219,10 +217,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
} }
// Ds desc for source in blockwise copy // Ds desc for source in blockwise copy
template <typename DsGridDescriptor_M_N> template <typename DsGridDesc_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -232,6 +229,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -232,6 +229,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
} }
// return block_id to E matrix tile idx (m0, n0) mapping // return block_id to E matrix tile idx (m0, n0) mapping
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{ {
...@@ -240,7 +238,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -240,7 +238,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2ETileMap> template <typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k, const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n, const DsGridDesc_M_N& ds_grid_desc_m_n,
...@@ -314,23 +316,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -314,23 +316,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
} }
using DefaultAGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using DsGridPointer = decltype(MakeDsGridPointer()); using DsGridPointer = decltype(MakeDsGridPointer());
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap> typename Block2ETileMap>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid, __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
...@@ -342,9 +334,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -342,9 +334,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const CDEElementwiseOperation& cde_element_op, const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
......
...@@ -16,6 +16,7 @@ namespace tensor_operation { ...@@ -16,6 +16,7 @@ namespace tensor_operation {
template <index_t NDimSpatial, device::ConvolutionForwardSpecialization ConvForwardSpecialization> template <index_t NDimSpatial, device::ConvolutionForwardSpecialization ConvForwardSpecialization>
struct TransformConvFwdToGemm struct TransformConvFwdToGemm
{ {
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
template <typename ALayout, template <typename ALayout,
...@@ -864,6 +865,29 @@ struct TransformConvFwdToGemm ...@@ -864,6 +865,29 @@ struct TransformConvFwdToGemm
return out_gemmm_gemmn_desc; return out_gemmm_gemmn_desc;
} }
// for output bias
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GK> ||
is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false>
static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
{
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1));
return out_gemmm_gemmn_desc;
}
}; };
} // namespace tensor_operation } // namespace tensor_operation
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_IGNORE_HPP #pragma once
#define CK_IGNORE_HPP
// https://en.cppreference.com/w/cpp/utility/tuple/ignore // https://en.cppreference.com/w/cpp/utility/tuple/ignore
...@@ -21,4 +20,3 @@ struct ignore_t ...@@ -21,4 +20,3 @@ struct ignore_t
inline constexpr detail::ignore_t ignore; inline constexpr detail::ignore_t ignore;
} // namespace ck } // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment