Commit 0b997ce4 authored by Chao Liu's avatar Chao Liu
Browse files

adding conv multiple D

parent 69d323de
...@@ -16,7 +16,7 @@ using S = ck::Sequence<Is...>; ...@@ -16,7 +16,7 @@ using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough; using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::UnaryConvert;
static constexpr auto ConvFwdDefault = static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
...@@ -48,18 +48,18 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc ...@@ -48,18 +48,18 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
2, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1 8, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM true, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim 2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1 8, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN true, // BBlockLdsExtraN
7, // CThreadTransferSrcDstVectorDim 7, // CThreadTransferSrcDstVectorDim
1>; // CThreadTransferDstScalarPerVector 1>; // CThreadTransferDstScalarPerVector
#else #else
using CShuffleDataType = float; using CShuffleDataType = ck::half_t;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvNDFwdInstance = using DeviceConvNDFwdInstance =
...@@ -69,37 +69,40 @@ using DeviceConvNDFwdInstance = ...@@ -69,37 +69,40 @@ using DeviceConvNDFwdInstance =
WeiDataType, // WeiDataType, //
AccDataType, // AccDataType, //
CShuffleDataType, // CShuffleDataType, //
ck::Tuple<>, ck::Tuple<>, //
OutDataType, // OutDataType, //
InElementOp, // Input Elementwise Operation InElementOp, // Input Elementwise Operation
WeiElementOp, // Weights Elementwise Operation WeiElementOp, // Weights Elementwise Operation
OutElementOp, // Output Elementwise Operation OutElementOp, // Output Elementwise Operation
ConvFwdDefault, // ConvForwardSpecialization ConvFwdDefault, // ConvForwardSpecialization
256, // BlockSize 1, //
128, // MPerBlock 256, // BlockSize
256, // NPerBlock 128, // MPerBlock
4, // K0PerBlock 256, // NPerBlock
8, // K1 32, // KPerBlock
32, // MPerXdl 8, // K1
32, // NPerXdl 32, // MPerXdl
2, // MXdlPerWave 32, // NPerXdl
4, // NXdlPerWave 2, // MXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 4, // NXdlPerWave
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
2, // ABlockTransferSrcVectorDim S<1, 0, 2>, // ABlockTransferSrcAccessOrder
8, // ABlockTransferSrcScalarPerVector 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferDstScalarPerVector_K1 8, // ABlockTransferSrcScalarPerVector
true, // ABlockLdsAddExtraM 8, // ABlockTransferDstScalarPerVector_K1
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 1, // ABlockLdsExtraM
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
2, // BBlockTransferSrcVectorDim S<1, 0, 2>, // BBlockTransferSrcAccessOrder
8, // BBlockTransferSrcScalarPerVector 2, // BBlockTransferSrcVectorDim
8, // BBlockTransferDstScalarPerVector_K1 8, // BBlockTransferSrcScalarPerVector
true, // BBlockLdsAddExtraN 8, // BBlockTransferDstScalarPerVector_K1
7, // CThreadTransferSrcDstVectorDim 1, // BBlockLdsExtraN
1>; // CThreadTransferDstScalarPerVector 1,
1,
S<1, 32, 1, 8>,
8>;
#endif #endif
int main(int argc, char* argv[]) int main(int argc, char* argv[])
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp" #include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp" #include "ck/device_utility/kernel_launch.hpp"
...@@ -23,6 +24,73 @@ namespace ck { ...@@ -23,6 +24,73 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatDsPointer,
typename FloatE,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_xdl_cshuffle(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatDsPointer p_ds_grid,
FloatE* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
} // namespace
// //
// @brief Device Convolution operation. // @brief Device Convolution operation.
// //
...@@ -39,7 +107,7 @@ namespace device { ...@@ -39,7 +107,7 @@ namespace device {
// 3D: // 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] // out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
// //
template <ck::index_t NDimSpatial, template <index_t NDimSpatial,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename AccDataType, typename AccDataType,
...@@ -50,31 +118,35 @@ template <ck::index_t NDimSpatial, ...@@ -50,31 +118,35 @@ template <ck::index_t NDimSpatial,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize, index_t NumGemmKPrefetchStage,
ck::index_t MPerBlock, index_t BlockSize,
ck::index_t NPerBlock, index_t MPerBlock,
ck::index_t K0PerBlock, index_t NPerBlock,
ck::index_t K1, index_t KPerBlock,
ck::index_t MPerXDL, index_t K1,
ck::index_t NPerXDL, index_t MPerXDL,
ck::index_t MXdlPerWave, index_t NPerXDL,
ck::index_t NXdlPerWave, index_t MXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1, index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1, index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsAddExtraM, index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsAddExtraN, index_t BBlockLdsExtraN,
ck::index_t CThreadTransferSrcDstVectorDim, index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CThreadTransferDstScalarPerVector> index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
: public DeviceConvFwd<NDimSpatial, : public DeviceConvFwd<NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
...@@ -96,8 +168,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -96,8 +168,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation> CDEElementwiseOperation>
{ {
using DeviceOp = DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle; using DeviceOp = DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -109,12 +184,12 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -109,12 +184,12 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static auto GetWeightTensorDescriptor(ck::index_t gemm_n, ck::index_t gemm_k) static auto GetWeightTensorDescriptor(ck::index_t gemm_n, ck::index_t gemm_k)
{ {
const ck::index_t gemm_k0 = gemm_k / GemmK1Number; const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
const auto wei_k_yxc_grid_desc = const auto wei_k_yxe_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k)); make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));
// wei_gemmk0_gemmn_gemmk1_grid_desc // wei_gemmk0_gemmn_gemmk1_grid_desc
return transform_tensor_descriptor( return transform_tensor_descriptor(
wei_k_yxc_grid_desc, wei_k_yxe_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_pass_through_transform(gemm_n)), make_pass_through_transform(gemm_n)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
...@@ -149,6 +224,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -149,6 +224,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const std::vector<ck::index_t>& input_left_pads, const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads) const std::vector<ck::index_t>& input_right_pads)
{ {
const ck::index_t gemm_k0 = gemm_k / GemmK1Number; const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
const index_t Wi = input_spatial_lengths[0]; const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0]; const index_t Wo = output_spatial_lengths[0];
...@@ -171,11 +247,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -171,11 +247,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
const auto in_n_wi_c_grid_desc = const auto in_n_wi_e_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_wo_e_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc, in_n_wi_e_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)), make_pass_through_transform(C)),
...@@ -183,7 +259,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -183,7 +259,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc, in_n_wo_e_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_merge_transform(make_tuple(N, Wo))), make_merge_transform(make_tuple(N, Wo))),
make_tuple(Sequence<2>{}, Sequence<0, 1>{}), make_tuple(Sequence<2>{}, Sequence<0, 1>{}),
...@@ -205,19 +281,19 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -205,19 +281,19 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const index_t InLeftPadW = input_left_pads[0]; const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0]; const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_grid_desc = const auto in_n_wi_e_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_wip_e_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc, in_n_wi_e_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_x_wo_e_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc, in_n_wip_e_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
...@@ -226,7 +302,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -226,7 +302,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmk_gemmmraw_grid_desc = const auto in_gemmk_gemmmraw_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc, transform_tensor_descriptor(in_n_x_wo_e_grid_desc,
make_tuple(make_merge_transform(make_tuple(X, C)), make_tuple(make_merge_transform(make_tuple(X, C)),
make_merge_transform(make_tuple(N, Wo))), make_merge_transform(make_tuple(N, Wo))),
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
...@@ -291,11 +367,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -291,11 +367,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
const auto in_n_hi_wi_c_grid_desc = const auto in_n_hi_wi_e_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_ho_wo_e_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc, in_n_hi_wi_e_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
...@@ -304,7 +380,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -304,7 +380,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_n_ho_wo_c_grid_desc, in_n_ho_wo_e_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_merge_transform(make_tuple(N, Ho, Wo))), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
...@@ -333,11 +409,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -333,11 +409,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const index_t InRightPadH = input_right_pads[0]; const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1]; const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_grid_desc = const auto in_n_hi_wi_e_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_hip_wip_e_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc, in_n_hi_wi_e_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi, InLeftPadW, InRightPadW),
...@@ -345,8 +421,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -345,8 +421,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_y_ho_x_wo_e_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, in_n_hip_wip_e_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
...@@ -356,7 +432,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -356,7 +432,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmmraw_grid_desc = const auto in_gemmk_gemmmraw_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, transform_tensor_descriptor(in_n_y_ho_x_wo_e_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)), make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
...@@ -372,6 +448,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -372,6 +448,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
// in_gemmk0_gemmm_gemmk1_grid_desc // in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor( return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc, in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(gemm_k0), make_tuple(make_pass_through_transform(gemm_k0),
make_right_pad_transform(gemm_m, gemm_m_pad), make_right_pad_transform(gemm_m, gemm_m_pad),
make_pass_through_transform(GemmK1Number)), make_pass_through_transform(GemmK1Number)),
...@@ -424,11 +501,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -424,11 +501,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
const auto in_n_di_hi_wi_c_grid_desc = const auto in_n_di_hi_wi_e_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_do_ho_wo_e_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc, in_n_di_hi_wi_e_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
...@@ -440,9 +517,10 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -440,9 +517,10 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc, in_n_do_ho_wo_e_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_merge_transform(make_tuple(N, Do, Ho, Wo))), make_merge_transform(make_tuple(N, Do, Ho, Wo))),
make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -473,11 +551,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -473,11 +551,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const index_t InRightPadH = input_right_pads[1]; const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2]; const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_grid_desc = const auto in_n_di_hi_wi_e_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_hip_wip_e_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc, in_n_di_hi_wi_e_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD), make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi, InLeftPadH, InRightPadH),
...@@ -488,8 +566,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -488,8 +566,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_z_do_y_ho_x_wo_e_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, in_n_hip_wip_e_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
...@@ -505,7 +583,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -505,7 +583,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
Sequence<7>{})); Sequence<7>{}));
const auto in_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( const auto in_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc, in_n_z_do_y_ho_x_wo_e_grid_desc,
make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)),
make_merge_transform(make_tuple(N, Do, Ho, Wo))), make_merge_transform(make_tuple(N, Do, Ho, Wo))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
...@@ -547,7 +625,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -547,7 +625,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
} }
static auto static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> input_spatial_lengths,
...@@ -568,7 +646,6 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -568,7 +646,6 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
// C = A^T*B
// A: // A:
const auto in_gemmk0_gemmm_gemmk1_grid_desc = const auto in_gemmk0_gemmm_gemmk1_grid_desc =
GetInputTensorDescriptor<NDimSpatial>(N, GetInputTensorDescriptor<NDimSpatial>(N,
...@@ -585,7 +662,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -585,7 +662,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
input_right_pads); input_right_pads);
// B: // B:
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = GetWeightTensorDescriptor(GemmN, GemmK); const auto wei_gemmk0_gemmn_gemmk1_grid_desc = GetWeightTensorDescriptor(GemmN, GemmK);
// C: // E:
const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmN, GemmMPad); const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmN, GemmMPad);
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
...@@ -594,74 +671,84 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -594,74 +671,84 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
} }
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc() static auto GetABEGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( return MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}); 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
} }
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc() static auto GetABEGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( return MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
} }
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc() static auto GetABEGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( return MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
} }
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>()); using ABEGridDescs = decltype(GetABEGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>; using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(ABEGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(ABEGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>; using EGridDesc_M_N = remove_cvref_t<decltype(ABEGridDescs{}[I2])>;
using Block2CTileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, CGridDesc_M_N>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CShuffleDataType,
DsDataType,
EDataType, EDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
EGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, KPerBlock,
K1,
K1,
MPerXDL, MPerXDL,
NPerXDL, NPerXDL,
K1,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
2, // ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_AK1,
false, // AThreadTransferSrcResetCoordinateAfterRun, false,
ABlockLdsAddExtraM, ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
2, // BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_BK1,
false, // BThreadTransferSrcResetCoordinateAfterRun, false,
BBlockLdsAddExtraN, BBlockLdsExtraN,
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, CShuffleMXdlPerWavePerShuffle,
7, // CThreadTransferSrcDstVectorDim, CShuffleNXdlPerWavePerShuffle,
CThreadTransferDstScalarPerVector>; CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
#if 0
using Block2ETileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, EGridDesc_M_N>;
#else
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
#endif
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -682,17 +769,18 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -682,17 +769,18 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
AElementwiseOperation in_element_op, AElementwiseOperation in_element_op,
BElementwiseOperation wei_element_op, BElementwiseOperation wei_element_op,
CDEElementwiseOperation out_element_op) CDEElementwiseOperation out_element_op)
: p_a_grid_{p_in_grid}, : p_a_grid_{static_cast<const ADataType*>(p_in_grid)},
p_b_grid_{p_wei_grid}, p_b_grid_{static_cast<const BDataType*>(p_wei_grid)},
p_c_grid_{p_out_grid}, p_ds_grid_{}, // FIXME
a_grid_desc_k0_m_k1_{}, p_e_grid_{static_cast<EDataType*>(p_out_grid)},
b_grid_desc_k0_n_k1_{}, a_grid_desc_ak0_m_ak1_{},
c_grid_desc_m_n_{}, b_grid_desc_bk0_n_bk1_{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, e_grid_desc_m_n_{},
block_2_ctile_map_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
in_element_op_{in_element_op}, block_2_etile_map_{},
wei_element_op_{wei_element_op}, a_element_op_{in_element_op},
out_element_op_{out_element_op}, b_element_op_{wei_element_op},
cde_element_op_{out_element_op},
Conv_N_{N}, Conv_N_{N},
Conv_K_{K}, Conv_K_{K},
Conv_C_{C}, Conv_C_{C},
...@@ -702,7 +790,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -702,7 +790,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
input_right_pads_{input_right_pads} input_right_pads_{input_right_pads}
{ {
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, DeviceOp::MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
K, K,
C, C,
input_spatial_lengths, input_spatial_lengths,
...@@ -713,35 +801,50 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -713,35 +801,50 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
input_left_pads, input_left_pads,
input_right_pads); input_right_pads);
a_grid_desc_k0_m_k1_ = descs[I0]; a_grid_desc_ak0_m_ak1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1]; b_grid_desc_bk0_n_bk1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2]; e_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; block_2_etile_map_ = Block2ETileMap{e_grid_desc_m_n_};
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_k0_n_k1_, b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_, e_grid_desc_m_n_,
block_2_ctile_map_)) block_2_etile_map_))
{ {
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
} }
} }
// private: // private:
// pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
EDataType* p_c_grid_; typename GridwiseGemm::DsGridPointer p_ds_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; EDataType* p_e_grid_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; // tensor descriptors
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
Block2CTileMap block_2_ctile_map_; StaticallyIndexedArray<
AElementwiseOperation in_element_op_; typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
BElementwiseOperation wei_element_op_; NumDTensor>
CDEElementwiseOperation out_element_op_; ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
index_t Conv_N_; index_t Conv_N_;
index_t Conv_K_; index_t Conv_K_;
...@@ -761,99 +864,84 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -761,99 +864,84 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
{ {
#if 0 #if 0
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_ak0_m_ak1_{" << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << ", " << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) std::cout << "arg.b_grid_desc_bk0_n_bk1_{" << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " << ", " << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif #endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_)) arg.block_2_etile_map_))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
} }
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K = const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float ave_time = 0; auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle<
{
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType, EDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
Block2CTileMap, DeviceOp::AGridDesc_AK0_M_AK1,
true>; DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray<
ave_time = launch_and_time_kernel(stream_config, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
kernel, NumDTensor>,
dim3(grid_size), typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
dim3(BlockSize), Block2ETileMap,
0, has_main_loop>;
arg.p_a_grid_,
arg.p_b_grid_, return launch_and_time_kernel(stream_config,
arg.p_c_grid_, kernel,
arg.a_grid_desc_k0_m_k1_, dim3(grid_size),
arg.b_grid_desc_k0_n_k1_, dim3(BlockSize),
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, 0,
arg.in_element_op_, arg.p_a_grid_,
arg.wei_element_op_, arg.p_b_grid_,
arg.out_element_op_, arg.p_ds_grid_,
arg.block_2_ctile_map_); arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
float avg_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
avg_time = launch_kernel(integral_constant<bool, true>{});
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< avg_time = launch_kernel(integral_constant<bool, false>{});
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
EDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
Block2CTileMap,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.in_element_op_,
arg.wei_element_op_,
arg.out_element_op_,
arg.block_2_ctile_map_);
} }
return ave_time; return avg_time;
} }
float Run(const BaseArgument* p_arg, float Run(const BaseArgument* p_arg,
...@@ -863,12 +951,6 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -863,12 +951,6 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
} }
}; };
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx908") if(ck::get_device_name() == "gfx908")
...@@ -892,12 +974,12 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -892,12 +974,12 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
return false; return false;
} }
// Input tensors can't be bigger than 2GB each. // tensors can't be bigger than 2GB each.
constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31); constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31);
if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 || if(arg.a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 ||
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 || arg.b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 ||
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) > GB2) arg.e_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) > GB2)
{ {
return false; return false;
} }
...@@ -937,17 +1019,17 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -937,17 +1019,17 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
return false; return false;
} }
// vector store C matrix into global memory // vector store D/E matrix into global memory
if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0)) if(!(arg.Conv_K_ % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{ {
return false; return false;
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_); arg.block_2_etile_map_);
} }
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -1043,7 +1125,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle ...@@ -1043,7 +1125,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization)
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -618,18 +618,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -618,18 +618,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
arg.block_2_etile_map_); arg.block_2_etile_map_);
}; };
float ave_time = 0; float avg_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); avg_time = launch_kernel(integral_constant<bool, true>{});
} }
else else
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); avg_time = launch_kernel(integral_constant<bool, false>{});
} }
return ave_time; return avg_time;
} }
// polymorphic // polymorphic
......
...@@ -12,16 +12,47 @@ namespace element_wise { ...@@ -12,16 +12,47 @@ namespace element_wise {
struct PassThrough struct PassThrough
{ {
template <typename T> template <typename Y, typename X>
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || y = x;
is_same<T, half_t>::value || is_same<T, bhalf_t>::value || }
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = x; y = x;
}; }
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
};
struct UnaryConvert
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
y = type_convert<Y>(x);
}
}; };
struct Scale struct Scale
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment