Unverified Commit f63a23ac authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

[MIOpen Downstream] Initial MIOpen integration (#52)

* update online kernel wrapper bundle all descriptors in a tuple

* change __CONSTANT__ to CONSTANT

* rename

* adding tuning

* added IsValidCompileParameter

* reorginze

* adding tunable for fp16 and int8

* fix kernel compile warning and bug fixes

* suppress warning about cast CONSTANT (address space 4) pointer

* fix building issue
parent 12649254
#include "common_header.hpp" #include "common_header.hpp"
#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r2.hpp" #include "gridwise_dynamic_contraction_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" #include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
using namespace ck; using namespace ck;
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type; constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_ACC_DATATYPE)>::type; constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type; constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize; constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr auto GN0 = Number<CK_PARAM_GN0>{}; constexpr auto GN0 = Number<CK_PARAM_GN0>{};
constexpr auto GK1 = Number<CK_PARAM_GK1>{}; constexpr auto GK1 = Number<CK_PARAM_GK1>{};
constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11; constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11;
constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11; constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11;
constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock; constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock;
constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11;
constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11; constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11;
constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread; constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11;
constexpr index_t BM10BN10ThreadClusterBM100 = CK_PARAM_BM10BN10ThreadClusterBM100; constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread;
constexpr index_t BM10BN10ThreadClusterBN100 = CK_PARAM_BM10BN10ThreadClusterBN100;
constexpr index_t BM10BN10ThreadClusterBM101 = CK_PARAM_BM10BN10ThreadClusterBM101; using BM10BN10ThreadClusterBM10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBM10Xs>;
constexpr index_t BM10BN10ThreadClusterBN101 = CK_PARAM_BM10BN10ThreadClusterBN101; using BM10BN10ThreadClusterBN10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBN10Xs>;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1>; Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1>;
...@@ -55,29 +58,26 @@ using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2> ...@@ -55,29 +58,26 @@ using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>
constexpr index_t CThreadTransferSrcDstVectorDim = 5; constexpr index_t CThreadTransferSrcDstVectorDim = 5;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBlockLoop);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HasDoubleTailKBlockLoop);
extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw_prepare( extern "C" __global__ void
index_t N, dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
index_t C, index_t C,
index_t Hi, index_t Hi,
index_t Wi, index_t Wi,
index_t K, index_t K,
index_t Y, index_t Y,
index_t X, index_t X,
index_t ConvStrideH, index_t ConvStrideH,
index_t ConvStrideW, index_t ConvStrideW,
index_t ConvDilationH, index_t ConvDilationH,
index_t ConvDilationW, index_t ConvDilationW,
index_t InLeftPadH, index_t InLeftPadH,
index_t InLeftPadW, index_t InLeftPadW,
index_t InRightPadH, index_t InRightPadH,
index_t InRightPadW, index_t InRightPadW,
void* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1, void* p_desc_tuple)
void* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1,
void* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
void* p_c_grid_block_cluster_blockid_to_gm10_gn10)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -160,12 +160,12 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k ...@@ -160,12 +160,12 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
using GridwiseContraction = using GridwiseContraction =
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
AGridDesc_GK0_GM0_GM1_GK1, AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1, BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1, CGridDesc_GM0_GM1_GN0_GN1,
...@@ -175,10 +175,8 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k ...@@ -175,10 +175,8 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k
BM1PerThreadBM11, BM1PerThreadBM11,
BN1PerThreadBN11, BN1PerThreadBN11,
BK0PerThread, BK0PerThread,
BM10BN10ThreadClusterBM100, BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN100, BM10BN10ThreadClusterBN10Xs,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
...@@ -202,47 +200,36 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k ...@@ -202,47 +200,36 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowIteratorHacks>;
auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
c_grid_desc_gm0_gm1_gn0_gn1);
auto c_grid_block_cluster_blockid_to_gm10_gn10 =
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1);
if(hipThreadIdx_x == 0)
{ {
*static_cast<decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1)*>( auto desc_tuple =
p_a_grid_desc_gk0_gm0_gm10_gm11_gk1) = a_grid_desc_gk0_gm0_gm10_gm11_gk1; make_tuple(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
*static_cast<decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1)*>( a_grid_desc_gk0_gm0_gm1_gk1),
p_b_grid_desc_gk0_gn0_gn10_gn11_gk1) = b_grid_desc_gk0_gn0_gn10_gn11_gk1; GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
*static_cast<decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1)*>( b_grid_desc_gk0_gn0_gn1_gk1),
p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1) = c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1; GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
*static_cast<decltype(c_grid_block_cluster_blockid_to_gm10_gn10)*>( c_grid_desc_gm0_gm1_gn0_gn1),
p_c_grid_block_cluster_blockid_to_gm10_gn10) = GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_block_cluster_blockid_to_gm10_gn10; c_grid_desc_gm0_gm1_gn0_gn1));
};
*static_cast<decltype(desc_tuple)*>(p_desc_tuple) = desc_tuple;
}
}; };
extern "C" __global__ void extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1, const void CONSTANT* p_desc_tuple)
const void __CONSTANT__* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1,
const void __CONSTANT__* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
const void __CONSTANT__* p_c_grid_block_cluster_blockid_to_gm10_gn10)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_desc = constexpr auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
...@@ -316,12 +303,12 @@ extern "C" __global__ void ...@@ -316,12 +303,12 @@ extern "C" __global__ void
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
using GridwiseContraction = using GridwiseContraction =
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
AGridDesc_GK0_GM0_GM1_GK1, AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1, BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1, CGridDesc_GM0_GM1_GN0_GN1,
...@@ -331,10 +318,8 @@ extern "C" __global__ void ...@@ -331,10 +318,8 @@ extern "C" __global__ void
BM1PerThreadBM11, BM1PerThreadBM11,
BN1PerThreadBN11, BN1PerThreadBN11,
BK0PerThread, BK0PerThread,
BM10BN10ThreadClusterBM100, BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN100, BM10BN10ThreadClusterBN10Xs,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
...@@ -371,18 +356,23 @@ extern "C" __global__ void ...@@ -371,18 +356,23 @@ extern "C" __global__ void
decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1)); c_grid_desc_gm0_gm1_gn0_gn1));
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = using DescTuple = decltype(make_tuple(AGridDesc_GK0_GM0_GM10_GM11_GK1{},
*reinterpret_cast<const AGridDesc_GK0_GM0_GM10_GM11_GK1*>( BGridDesc_GK0_GN0_GN10_GN11_GK1{},
(const void*)p_a_grid_desc_gk0_gm0_gm10_gm11_gk1); CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{},
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = CGridBlockCluster_BlockId_To_GM10_GN10{}));
*reinterpret_cast<const BGridDesc_GK0_GN0_GN10_GN11_GK1*>(
(const void*)p_b_grid_desc_gk0_gn0_gn10_gn11_gk1); const auto desc_tuple = *reinterpret_cast<const DescTuple*>(
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = #pragma clang diagnostic push
*reinterpret_cast<const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1*>( #pragma clang diagnostic ignored "-Wold-style-cast"
(const void*)p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1); // TODO: how to cast?
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = (const void*)p_desc_tuple
*reinterpret_cast<const CGridBlockCluster_BlockId_To_GM10_GN10*>( #pragma clang diagnostic pop
(const void*)p_c_grid_block_cluster_blockid_to_gm10_gn10); );
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0];
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1];
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = desc_tuple[I2];
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = desc_tuple[I3];
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......
...@@ -12,17 +12,17 @@ ...@@ -12,17 +12,17 @@
#include "conv_common.hpp" #include "conv_common.hpp"
#include "host_conv.hpp" #include "host_conv.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1 #define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 1 #define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V6R1_NCHW 1
#define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 #define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
...@@ -301,19 +301,20 @@ int main(int argc, char* argv[]) ...@@ -301,19 +301,20 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<in_data_t,
acc_data_t, acc_data_t,
out_data_t>(tmp[I0], out_data_t>(
tmp[I1], tmp[I0],
tmp[I2], tmp[I1],
tmp[I3], tmp[I2],
tmp[I4], tmp[I3],
tmp[I5], tmp[I4],
tmp[I6], tmp[I5],
in, tmp[I6],
wei, in,
out_device, wei,
nrepeat); out_device,
nrepeat);
} }
#endif #endif
...@@ -327,9 +328,9 @@ int main(int argc, char* argv[]) ...@@ -327,9 +328,9 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nhwc(); const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk<in_data_t,
acc_data_t, acc_data_t,
out_data_t>( out_data_t>(
tmp[I0], tmp[I0],
tmp[I1], tmp[I1],
tmp[I2], tmp[I2],
...@@ -354,19 +355,20 @@ int main(int argc, char* argv[]) ...@@ -354,19 +355,20 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<in_data_t,
acc_data_t, acc_data_t,
out_data_t>(tmp[I0], out_data_t>(
tmp[I1], tmp[I0],
tmp[I2], tmp[I1],
tmp[I3], tmp[I2],
tmp[I4], tmp[I3],
tmp[I5], tmp[I4],
tmp[I6], tmp[I5],
in, tmp[I6],
wei, in,
out_device, wei,
nrepeat); out_device,
nrepeat);
} }
#endif #endif
...@@ -380,20 +382,21 @@ int main(int argc, char* argv[]) ...@@ -380,20 +382,21 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t,
16, 16,
acc_data_t, acc_data_t,
out_data_t>(tmp[I0], out_data_t>(
tmp[I1], tmp[I0],
tmp[I2], tmp[I1],
tmp[I3], tmp[I2],
tmp[I4], tmp[I3],
tmp[I5], tmp[I4],
tmp[I6], tmp[I5],
in, tmp[I6],
wei, in,
out_device, wei,
nrepeat); out_device,
nrepeat);
} }
#endif #endif
......
...@@ -264,7 +264,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx ...@@ -264,7 +264,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc), decltype(in_gemmm_gemmn_grid_desc),
......
...@@ -236,7 +236,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k ...@@ -236,7 +236,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc), decltype(in_gemmm_gemmn_grid_desc),
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_gemm_v1r2.hpp" #include "driver_dynamic_gemm_dlops_v1r2.hpp"
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
...@@ -14,7 +14,7 @@ template <typename TInWei, ...@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths, const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths, const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths, const OutLengths& out_n_k_ho_wo_lengths,
...@@ -142,12 +142,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -142,12 +142,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_dynamic_gemm_v1r2< float ave_time = driver_dynamic_gemm_dlops_v1r2<
BlockSize, BlockSize,
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk_gemmm_grid_desc), decltype(wei_gemmk_gemmm_grid_desc),
decltype(in_gemmk_gemmn_grid_desc), decltype(in_gemmk_gemmn_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
...@@ -220,7 +220,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -220,7 +220,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(descs[I0]), decltype(descs[I0]),
decltype(descs[I1]), decltype(descs[I1]),
decltype(descs[I2]), decltype(descs[I2]),
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" #include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_v1r3.hpp" #include "driver_dynamic_gemm_dlops_v1r3.hpp"
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
...@@ -14,7 +14,7 @@ template <typename TInWei, ...@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths, const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths, const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths, const OutLengths& out_n_ho_wo_k_lengths,
...@@ -56,7 +56,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( ...@@ -56,7 +56,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk(
const auto out_n_ho_wo_k_desc = const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
#if 0 #if 1
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32 // [M, N, K0, K1] = [128, 128, 8, 1] for fp32
// cdata = 64, BlockSize = 256 // cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -70,10 +70,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( ...@@ -70,10 +70,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk(
constexpr index_t GemmN1PerThreadN111 = 4; constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8; using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8; using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>; using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
...@@ -102,10 +100,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( ...@@ -102,10 +100,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk(
constexpr index_t GemmN1PerThreadN111 = 4; constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8; using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8; using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>; using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
...@@ -134,10 +130,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( ...@@ -134,10 +130,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk(
constexpr index_t GemmN1PerThreadN111 = 4; constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8; using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8; using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>; using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
...@@ -211,12 +205,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( ...@@ -211,12 +205,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk(
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_dynamic_gemm_v1r3< float ave_time = driver_dynamic_gemm_dlops_v1r3<
BlockSize, BlockSize,
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
...@@ -226,10 +220,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( ...@@ -226,10 +220,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk(
GemmM1PerThreadM111, GemmM1PerThreadM111,
GemmN1PerThreadN111, GemmN1PerThreadN111,
GemmKPerThread, GemmKPerThread,
GemmM11N11ThreadClusterM1100, GemmM11N11ThreadClusterM110Xs,
GemmM11N11ThreadClusterN1100, GemmM11N11ThreadClusterN110Xs,
GemmM11N11ThreadClusterM1101,
GemmM11N11ThreadClusterN1101,
GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1, GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1,
GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1, GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1,
Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
......
...@@ -145,7 +145,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk ...@@ -145,7 +145,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
...@@ -165,7 +165,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh ...@@ -165,7 +165,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
...@@ -229,7 +229,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh ...@@ -229,7 +229,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
...@@ -288,7 +288,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -288,7 +288,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" #include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp" #include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
template <typename TInWei, template <typename TInWei,
ck::index_t InWeiVectorSize, ck::index_t InWeiVectorSize,
...@@ -15,7 +15,7 @@ template <typename TInWei, ...@@ -15,7 +15,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths, const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths, const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths, const OutLengths& out_n_k_ho_wo_lengths,
...@@ -145,9 +145,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -145,9 +145,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr auto conv_driver = constexpr auto conv_driver =
#if 0 #if 0
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
#else #else
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
#endif #endif
<BlockSize, <BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type, typename vector_type<TInWei, InWeiVectorSize>::type,
......
#pragma once
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" #include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_contraction_v1r2.hpp" #include "driver_dynamic_contraction_dlops_v1r2.hpp"
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
...@@ -14,7 +15,7 @@ template <typename TInWei, ...@@ -14,7 +15,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths, const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths, const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths, const OutLengths& out_n_k_ho_wo_lengths,
...@@ -66,10 +67,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( ...@@ -66,10 +67,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw(
constexpr index_t BN1PerThreadBN11 = 4; constexpr index_t BN1PerThreadBN11 = 4;
constexpr index_t BK0PerThread = 1; constexpr index_t BK0PerThread = 1;
constexpr index_t BM10BN10ThreadClusterBM100 = 8; using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>;
constexpr index_t BM10BN10ThreadClusterBN100 = 8; using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>;
constexpr index_t BM10BN10ThreadClusterBM101 = 2;
constexpr index_t BM10BN10ThreadClusterBN101 = 2;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
...@@ -100,10 +99,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( ...@@ -100,10 +99,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw(
constexpr index_t BN1PerThreadBN11 = 4; constexpr index_t BN1PerThreadBN11 = 4;
constexpr index_t BK0PerThread = 1; constexpr index_t BK0PerThread = 1;
constexpr index_t BM10BN10ThreadClusterBM100 = 8; using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>;
constexpr index_t BM10BN10ThreadClusterBN100 = 8; using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>;
constexpr index_t BM10BN10ThreadClusterBM101 = 2;
constexpr index_t BM10BN10ThreadClusterBN101 = 2;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>; using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>;
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
...@@ -183,12 +180,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( ...@@ -183,12 +180,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw(
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_dynamic_contraction_v1r2< float ave_time = driver_dynamic_contraction_dlops_v1r2<
BlockSize, BlockSize,
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_grid_desc_gk0_gm0_gm1_gk1), decltype(wei_grid_desc_gk0_gm0_gm1_gk1),
decltype(in_grid_desc_gk0_gn0_gn1_gk1), decltype(in_grid_desc_gk0_gn0_gn1_gk1),
decltype(out_grid_desc_gm0_gm1_gn0_gn1), decltype(out_grid_desc_gm0_gm1_gn0_gn1),
...@@ -198,10 +195,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( ...@@ -198,10 +195,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw(
BM1PerThreadBM11, BM1PerThreadBM11,
BN1PerThreadBN11, BN1PerThreadBN11,
BK0PerThread, BK0PerThread,
BM10BN10ThreadClusterBM100, BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN100, BM10BN10ThreadClusterBN10Xs,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder
......
#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP #ifndef DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP #define DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r2.hpp" #include "gridwise_dynamic_contraction_dlops_v1r2.hpp"
namespace ck { template <ck::index_t BlockSize,
template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGridDesc_GK0_GM0_GM1_GK1, typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1, typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1, typename CGridDesc_GM0_GM1_GN0_GN1,
index_t GM1PerBlockGM11, ck::index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11, ck::index_t GN1PerBlockGN11,
index_t GK0PerBlock, ck::index_t GK0PerBlock,
index_t BM1PerThreadBM11, ck::index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11, ck::index_t BN1PerThreadBN11,
index_t BK0PerThread, ck::index_t BK0PerThread,
index_t BM10BN10ThreadClusterBM100, typename BM10BN10ThreadClusterBM10Xs,
index_t BM10BN10ThreadClusterBN100, typename BM10BN10ThreadClusterBN10Xs,
index_t BM10BN10ThreadClusterBM101,
index_t BM10BN10ThreadClusterBN101,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
...@@ -41,28 +37,30 @@ template <index_t BlockSize, ...@@ -41,28 +37,30 @@ template <index_t BlockSize,
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder, typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, ck::index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridIteratorHacks,
typename BGridIteratorHacks, typename BGridIteratorHacks,
typename CGridIteratorHacks, typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowIteratorHacks>
__host__ float __host__ float
driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1, const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks,
index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -72,7 +70,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -72,7 +70,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
// GEMM // GEMM
using GridwiseContraction = using GridwiseContraction =
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
...@@ -87,10 +85,8 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -87,10 +85,8 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
BM1PerThreadBM11, BM1PerThreadBM11,
BN1PerThreadBN11, BN1PerThreadBN11,
BK0PerThread, BK0PerThread,
BM10BN10ThreadClusterBM100, BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN100, BM10BN10ThreadClusterBN10Xs,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
...@@ -182,7 +178,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -182,7 +178,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_contraction_v1r2< const auto kernel = kernel_dynamic_contraction_dlops_v1r2<
GridwiseContraction, GridwiseContraction,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -209,7 +205,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -209,7 +205,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_contraction_v1r2< const auto kernel = kernel_dynamic_contraction_dlops_v1r2<
GridwiseContraction, GridwiseContraction,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -236,7 +232,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -236,7 +232,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_contraction_v1r2< const auto kernel = kernel_dynamic_contraction_dlops_v1r2<
GridwiseContraction, GridwiseContraction,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -263,7 +259,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -263,7 +259,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
} }
else else
{ {
const auto kernel = kernel_dynamic_contraction_v1r2< const auto kernel = kernel_dynamic_contraction_dlops_v1r2<
GridwiseContraction, GridwiseContraction,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -291,6 +287,4 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -291,6 +287,4 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
return ave_time; return ave_time;
} }
} // namespace ck
#endif #endif
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP #ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP #define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v2.hpp" #include "gridwise_dynamic_gemm_dlops_v2.hpp"
#include "gridwise_operation_wrapper.hpp" #include "gridwise_operation_wrapper.hpp"
namespace ck { template <ck::index_t BlockSize,
template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
index_t KPerBlock, ck::index_t KPerBlock,
index_t HoPerBlock, ck::index_t HoPerBlock,
index_t WoPerBlock, ck::index_t WoPerBlock,
index_t EPerBlock, ck::index_t EPerBlock,
index_t KPerThread, ck::index_t KPerThread,
index_t HoPerThread, ck::index_t HoPerThread,
index_t WoPerThread, ck::index_t WoPerThread,
index_t EPerThread, ck::index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E_K, typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E_K, typename ABlockTransferThreadClusterLengths_E_K,
index_t ABlockTransferSrcScalarPerVector_E, ck::index_t ABlockTransferSrcScalarPerVector_E,
index_t ABlockTransferDstScalarPerVector_K, ck::index_t ABlockTransferDstScalarPerVector_K,
index_t BThreadTransferSrcScalarPerVector_W, ck::index_t BThreadTransferSrcScalarPerVector_W,
index_t CThreadTransferDstScalarPerVector_W> ck::index_t CThreadTransferDstScalarPerVector_W>
struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
{ {
template <typename... Wei, template <typename... Wei,
typename... In, typename... In,
...@@ -36,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -36,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, __host__ void Run(const ck::DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc, const ck::DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc, const ck::DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -47,6 +45,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -47,6 +45,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const FloatAB* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
FloatC* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
using namespace ck;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -169,12 +169,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -169,12 +169,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
#if 1 #if 1
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3< using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
decltype(out_k_n_ho_wo_global_desc), decltype(out_k_n_ho_wo_global_desc),
...@@ -349,5 +349,4 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -349,5 +349,4 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
#endif #endif
} }
}; };
} // namespace ck
#endif #endif
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_OUTPAD_HPP #ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_OUTPAD_HPP #define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v2.hpp" #include "gridwise_dynamic_gemm_dlops_v2.hpp"
#include "gridwise_operation_wrapper.hpp" #include "gridwise_operation_wrapper.hpp"
namespace ck { template <ck::index_t BlockSize,
template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
index_t KPerBlock, ck::index_t KPerBlock,
index_t HoPerBlock, ck::index_t HoPerBlock,
index_t WoPerBlock, ck::index_t WoPerBlock,
index_t EPerBlock, ck::index_t EPerBlock,
index_t KPerThread, ck::index_t KPerThread,
index_t HoPerThread, ck::index_t HoPerThread,
index_t WoPerThread, ck::index_t WoPerThread,
index_t EPerThread, ck::index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E_K, typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E_K, typename ABlockTransferThreadClusterLengths_E_K,
index_t ABlockTransferSrcScalarPerVector_E, ck::index_t ABlockTransferSrcScalarPerVector_E,
index_t ABlockTransferDstScalarPerVector_K, ck::index_t ABlockTransferDstScalarPerVector_K,
index_t BThreadTransferSrcScalarPerVector_W, ck::index_t BThreadTransferSrcScalarPerVector_W,
index_t CThreadTransferDstScalarPerVector_W> ck::index_t CThreadTransferDstScalarPerVector_W>
struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
{ {
template <typename... Wei, template <typename... Wei,
typename... In, typename... In,
...@@ -36,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -36,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, __host__ void Run(const ck::DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc, const ck::DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc, const ck::DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -47,6 +45,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -47,6 +45,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
FloatC* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
using namespace ck;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -181,12 +181,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -181,12 +181,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0>{}));
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3< using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
...@@ -364,5 +364,4 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -364,5 +364,4 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
} }
} }
}; };
} // namespace ck
#endif #endif
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1R2 #ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R2
#define CK_DRIVER_DYNAMIC_GEMM_V1R2 #define DRIVER_DYNAMIC_GEMM_DLOPS_V1R2
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r2.hpp" #include "gridwise_dynamic_gemm_dlops_v1r2.hpp"
namespace ck { template <ck::index_t BlockSize,
template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AKMGridDesc, typename AKMGridDesc,
typename BKNGridDesc, typename BKNGridDesc,
typename CMNGridDesc, typename CMNGridDesc,
index_t MPerBlock, ck::index_t MPerBlock,
index_t NPerBlock, ck::index_t NPerBlock,
index_t KPerBlock, ck::index_t KPerBlock,
index_t M1PerThread, ck::index_t M1PerThread,
index_t N1PerThread, ck::index_t N1PerThread,
index_t KPerThread, ck::index_t KPerThread,
index_t M1N1ThreadClusterM10, ck::index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10, ck::index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11, ck::index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11, ck::index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M0_M1, typename ABlockTransferThreadSliceLengths_K_M0_M1,
typename ABlockTransferThreadClusterLengths_K_M0_M1, typename ABlockTransferThreadClusterLengths_K_M0_M1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim, ck::index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector, ck::index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M1, ck::index_t ABlockTransferDstScalarPerVector_M1,
bool AThreadTransferSrcResetCoordinateAfterRun, bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N0_N1, typename BBlockTransferThreadSliceLengths_K_N0_N1,
typename BBlockTransferThreadClusterLengths_K_N0_N1, typename BBlockTransferThreadClusterLengths_K_N0_N1,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim, ck::index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector, ck::index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N1, ck::index_t BBlockTransferDstScalarPerVector_N1,
bool BThreadTransferSrcResetCoordinateAfterRun, bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, ck::index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridIteratorHacks,
typename BGridIteratorHacks, typename BGridIteratorHacks,
typename CGridIteratorHacks, typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
const AKMGridDesc& a_k_m_grid_desc, const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc, const BKNGridDesc& b_k_n_grid_desc,
const CMNGridDesc& c_m_n_grid_desc, const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks,
index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -73,48 +73,48 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -73,48 +73,48 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
// GEMM // GEMM
using GridwiseGemm = using GridwiseGemm =
GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize, GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
AKMGridDesc, AKMGridDesc,
BKNGridDesc, BKNGridDesc,
CMNGridDesc, CMNGridDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
M1N1ThreadClusterM10, M1N1ThreadClusterM10,
M1N1ThreadClusterN10, M1N1ThreadClusterN10,
M1N1ThreadClusterM11, M1N1ThreadClusterM11,
M1N1ThreadClusterN11, M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1, ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1, ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1, BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1, BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowIteratorHacks>;
const auto M = a_k_m_grid_desc.GetLength(I1); const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1); const auto N = b_k_n_grid_desc.GetLength(I1);
...@@ -122,7 +122,8 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -122,7 +122,8 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc)) if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
{ {
throw std::runtime_error("wrong! GridwiseDynamicGemm_km_kn_mn_v1r2 has invalid setting"); throw std::runtime_error(
"wrong! GridwiseDynamicGemmDlops_km_kn_mn_v1r2 has invalid setting");
} }
const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
...@@ -173,15 +174,15 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -173,15 +174,15 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r2<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -200,15 +201,15 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -200,15 +201,15 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r2<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -227,15 +228,15 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -227,15 +228,15 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r2<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -254,15 +255,15 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -254,15 +255,15 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
else else
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r2<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -298,15 +299,15 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -298,15 +299,15 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r2<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -318,23 +319,23 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -318,23 +319,23 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r2<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -346,23 +347,23 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -346,23 +347,23 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r2<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -374,23 +375,23 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -374,23 +375,23 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
} }
else else
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r2<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -402,15 +403,13 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -402,15 +403,13 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
} }
return ave_time; return ave_time;
#endif #endif
} }
} // namespace ck
#endif #endif
#ifndef CK_DRIVER_DYNAMIC_GEMM_v1r3 #ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R3
#define CK_DRIVER_DYNAMIC_GEMM_v1r3 #define DRIVER_DYNAMIC_GEMM_DLOPS_V1R3
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r3.hpp" #include "gridwise_dynamic_gemm_dlops_v1r3.hpp"
namespace ck { template <ck::index_t BlockSize,
template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AK0MK1GridDesc, typename AK0MK1GridDesc,
typename BK0NK1GridDesc, typename BK0NK1GridDesc,
typename CMNGridDesc, typename CMNGridDesc,
index_t MPerBlock, ck::index_t MPerBlock,
index_t NPerBlock, ck::index_t NPerBlock,
index_t KPerBlock, ck::index_t KPerBlock,
index_t M1PerThread, ck::index_t M1PerThread,
index_t N1PerThread, ck::index_t N1PerThread,
index_t KPerThread, ck::index_t KPerThread,
index_t M1N1ThreadClusterM10, typename M1N1ThreadClusterM1Xs,
index_t M1N1ThreadClusterN10, typename M1N1ThreadClusterN1Xs,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
...@@ -41,27 +37,29 @@ template <index_t BlockSize, ...@@ -41,27 +37,29 @@ template <index_t BlockSize,
typename BBlockTransferSrcVectorTensorContiguousDimOrder, typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, ck::index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridIteratorHacks,
typename BGridIteratorHacks, typename BGridIteratorHacks,
typename CGridIteratorHacks, typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc, const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks,
index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -71,46 +69,44 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, ...@@ -71,46 +69,44 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
// GEMM // GEMM
using GridwiseGemm = using GridwiseGemm =
GridwiseDynamicGemm_km_kn_mn_v1r3<BlockSize, GridwiseDynamicGemmDlops_km_kn_mn_v1r3<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
AK0MK1GridDesc, AK0MK1GridDesc,
BK0NK1GridDesc, BK0NK1GridDesc,
CMNGridDesc, CMNGridDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
M1N1ThreadClusterM10, M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN10, M1N1ThreadClusterN1Xs,
M1N1ThreadClusterM11, ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
M1N1ThreadClusterN11, ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferSrcAccessOrder,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferSrcAccessOrder,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder, CThreadTransferSrcDstAccessOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstVectorDim,
CThreadTransferSrcDstAccessOrder, CThreadTransferDstScalarPerVector,
CThreadTransferSrcDstVectorDim, AGridIteratorHacks,
CThreadTransferDstScalarPerVector, BGridIteratorHacks,
AGridIteratorHacks, CGridIteratorHacks,
BGridIteratorHacks, AGridMoveSliceWindowIteratorHacks,
CGridIteratorHacks, BGridMoveSliceWindowIteratorHacks>;
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
const auto M = a_k0_m_k1_grid_desc.GetLength(I1); const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
...@@ -118,7 +114,8 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, ...@@ -118,7 +114,8 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{ {
throw std::runtime_error("wrong! GridwiseDynamicGemm_km_kn_mn_v1r3 has invalid setting"); throw std::runtime_error(
"wrong! GridwiseDynamicGemmDlops_km_kn_mn_v1r3 has invalid setting");
} }
const auto a_k0_m0_m1_k1_grid_desc = const auto a_k0_m0_m1_k1_grid_desc =
...@@ -173,15 +170,15 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, ...@@ -173,15 +170,15 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -200,15 +197,15 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, ...@@ -200,15 +197,15 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -227,15 +224,15 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, ...@@ -227,15 +224,15 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -254,15 +251,15 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, ...@@ -254,15 +251,15 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
else else
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -298,15 +295,15 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, ...@@ -298,15 +295,15 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -318,23 +315,23 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, ...@@ -318,23 +315,23 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -346,23 +343,23 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, ...@@ -346,23 +343,23 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -374,23 +371,23 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, ...@@ -374,23 +371,23 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
} }
else else
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm, kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -402,15 +399,13 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, ...@@ -402,15 +399,13 @@ __host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
} }
return ave_time; return ave_time;
#endif #endif
} }
} // namespace ck
#endif #endif
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 #ifndef DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 #define DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" #include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
namespace ck { template <ck::index_t BlockSize,
template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AK0MK1GridDesc, typename AK0MK1GridDesc,
typename BK0NK1GridDesc, typename BK0NK1GridDesc,
typename CMNGridDesc, typename CMNGridDesc,
index_t MPerBlock, ck::index_t MPerBlock,
index_t NPerBlock, ck::index_t NPerBlock,
index_t KPerBlock, ck::index_t KPerBlock,
index_t MPerWave, ck::index_t MPerWave,
index_t NPerWave, ck::index_t NPerWave,
index_t K1, ck::index_t K1,
index_t MRepeat, ck::index_t MRepeat,
index_t NRepeat, ck::index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1, typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim, ck::index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector, ck::index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1, ck::index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun, bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1, typename BBlockTransferThreadSliceLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim, ck::index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector, ck::index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1, ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun, bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, ck::index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks, typename AGridIteratorHacks,
typename BGridIteratorHacks, typename BGridIteratorHacks,
typename CGridIteratorHacks, typename CGridIteratorHacks,
...@@ -60,9 +58,11 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -60,9 +58,11 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks,
index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -176,23 +176,21 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -176,23 +176,21 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time = launch_and_time_kernel( float ave_time =
kernel, launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void __CONSTANT__*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(), (void CONSTANT*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); (void CONSTANT*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
#endif #endif
return ave_time; return ave_time;
} }
} // namespace ck
#endif #endif
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "handle.hpp" #include "handle.hpp"
#include "hipCheck.hpp" #include "hipCheck.hpp"
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp" #include "online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" #include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp" #include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
...@@ -35,6 +35,7 @@ enum ConvForwardAlgo ...@@ -35,6 +35,7 @@ enum ConvForwardAlgo
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
using namespace ck_driver;
using size_t = std::size_t; using size_t = std::size_t;
hipStream_t stream; hipStream_t stream;
...@@ -93,7 +94,7 @@ int main(int argc, char* argv[]) ...@@ -93,7 +94,7 @@ int main(int argc, char* argv[])
using in_data_t = float; using in_data_t = float;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = float; using out_data_t = float;
#elif 1 #elif 0
using in_data_t = half_t; using in_data_t = half_t;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = half_t; using out_data_t = half_t;
...@@ -225,25 +226,25 @@ int main(int argc, char* argv[]) ...@@ -225,25 +226,25 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw* tunable = tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* tunable =
&default_tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw; &default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw;
online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t, online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<
acc_data_t, in_data_t,
out_data_t>( acc_data_t,
handle, out_data_t>(handle,
tmp[I0], tmp[I0],
tmp[I1], tmp[I1],
tmp[I2], tmp[I2],
conv_strides, conv_strides,
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
in, in,
wei, wei,
out_device, out_device,
tunable, tunable,
nrepeat); nrepeat);
} }
#endif #endif
...@@ -257,24 +258,105 @@ int main(int argc, char* argv[]) ...@@ -257,24 +258,105 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
const auto tunable = tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw{}; #if 1
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = {
online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw<in_data_t, get_datatype_enum_from_type<in_data_t>::value,
acc_data_t, get_datatype_enum_from_type<acc_data_t>::value,
out_data_t>( get_datatype_enum_from_type<out_data_t>::value,
handle, 256,
tmp[I0], 4,
tmp[I1], 1,
tmp[I2], 128,
conv_strides, 32,
conv_dilations, 8,
in_left_pads, 4,
in_right_pads, 4,
in, 1,
wei, {8, 2},
out_device, {8, 2},
tunable, {4, 1, 1, 1, 1},
nrepeat); {2, 1, 1, 128, 1},
{4, 1, 1, 1, 1},
{1, 1, 1, 1, 1},
{1, 4, 1, 1, 1},
{8, 1, 1, 32, 1},
{1, 1, 1, 1, 1},
{1, 1, 1, 1, 1},
4,
true,
true};
#elif 0
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = {
get_datatype_enum_from_type<in_data_t>::value,
get_datatype_enum_from_type<acc_data_t>::value,
get_datatype_enum_from_type<out_data_t>::value,
256,
4,
2,
128,
32,
8,
4,
4,
1,
{8, 2},
{8, 2},
{4, 1, 1, 1, 2},
{2, 1, 1, 128, 1},
{4, 1, 1, 1, 1},
{1, 1, 1, 1, 1},
{1, 4, 1, 1, 2},
{8, 1, 1, 32, 1},
{1, 1, 1, 1, 1},
{1, 1, 1, 1, 1},
4,
true,
true};
#elif 1
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = {
get_datatype_enum_from_type<in_data_t>::value,
get_datatype_enum_from_type<acc_data_t>::value,
get_datatype_enum_from_type<out_data_t>::value,
256,
4,
4,
128,
32,
8,
4,
4,
1,
{8, 2},
{8, 2},
{4, 1, 1, 1, 4},
{2, 1, 1, 128, 1},
{4, 1, 1, 1, 1},
{1, 1, 1, 1, 1},
{1, 4, 1, 1, 4},
{8, 1, 1, 32, 1},
{1, 1, 1, 1, 1},
{1, 1, 1, 1, 1},
4,
true,
true};
#endif
online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<
in_data_t,
acc_data_t,
out_data_t>(handle,
tmp[I0],
tmp[I1],
tmp[I2],
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
in,
wei,
out_device,
compile_param,
nrepeat);
} }
#endif #endif
...@@ -355,13 +437,15 @@ int main(int argc, char* argv[]) ...@@ -355,13 +437,15 @@ int main(int argc, char* argv[])
check_error(out_host, out_device); check_error(out_host, out_device);
#if 0
if(do_log) if(do_log)
{ {
LogRange(std::cout << "in : ", in.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
LogRange(std::cout << "wei: ", wei.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
} }
#endif
} }
delete handle; delete handle;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment