Unverified Commit f63a23ac authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

[MIOpen Downstream] Initial MIOpen integration (#52)

* update online kernel wrapper to bundle all descriptors in a tuple

* change __CONSTANT__ to CONSTANT

* rename

* adding tuning

* added IsValidCompileParameter

* reorganize

* adding tunable for fp16 and int8

* fix kernel compile warning and bug fixes

* suppress warning about casting a CONSTANT (address space 4) pointer

* fix building issue
parent 12649254
#include "common_header.hpp" #include "common_header.hpp"
#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r2.hpp" #include "gridwise_dynamic_contraction_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" #include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
using namespace ck; using namespace ck;
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type; constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_ACC_DATATYPE)>::type; constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type; constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize; constexpr index_t BlockSize = CK_PARAM_BlockSize;
...@@ -19,13 +22,13 @@ constexpr auto GK1 = Number<CK_PARAM_GK1>{}; ...@@ -19,13 +22,13 @@ constexpr auto GK1 = Number<CK_PARAM_GK1>{};
constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11; constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11;
constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11; constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11;
constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock; constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock;
constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11; constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11;
constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11; constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11;
constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread; constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread;
constexpr index_t BM10BN10ThreadClusterBM100 = CK_PARAM_BM10BN10ThreadClusterBM100;
constexpr index_t BM10BN10ThreadClusterBN100 = CK_PARAM_BM10BN10ThreadClusterBN100; using BM10BN10ThreadClusterBM10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBM10Xs>;
constexpr index_t BM10BN10ThreadClusterBM101 = CK_PARAM_BM10BN10ThreadClusterBM101; using BM10BN10ThreadClusterBN10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBN10Xs>;
constexpr index_t BM10BN10ThreadClusterBN101 = CK_PARAM_BM10BN10ThreadClusterBN101;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1>; Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1>;
...@@ -55,11 +58,11 @@ using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2> ...@@ -55,11 +58,11 @@ using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>
constexpr index_t CThreadTransferSrcDstVectorDim = 5; constexpr index_t CThreadTransferSrcDstVectorDim = 5;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBlockLoop);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HasDoubleTailKBlockLoop);
extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw_prepare( extern "C" __global__ void
index_t N, dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
index_t C, index_t C,
index_t Hi, index_t Hi,
index_t Wi, index_t Wi,
...@@ -74,10 +77,7 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k ...@@ -74,10 +77,7 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k
index_t InLeftPadW, index_t InLeftPadW,
index_t InRightPadH, index_t InRightPadH,
index_t InRightPadW, index_t InRightPadW,
void* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1, void* p_desc_tuple)
void* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1,
void* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
void* p_c_grid_block_cluster_blockid_to_gm10_gn10)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -160,12 +160,12 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k ...@@ -160,12 +160,12 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
using GridwiseContraction = using GridwiseContraction =
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
AGridDesc_GK0_GM0_GM1_GK1, AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1, BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1, CGridDesc_GM0_GM1_GN0_GN1,
...@@ -175,10 +175,8 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k ...@@ -175,10 +175,8 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k
BM1PerThreadBM11, BM1PerThreadBM11,
BN1PerThreadBN11, BN1PerThreadBN11,
BK0PerThread, BK0PerThread,
BM10BN10ThreadClusterBM100, BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN100, BM10BN10ThreadClusterBN10Xs,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
...@@ -202,47 +200,36 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k ...@@ -202,47 +200,36 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowIteratorHacks>;
auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1); {
auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = auto desc_tuple =
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1); make_tuple(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = a_grid_desc_gk0_gm0_gm1_gk1),
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
b_grid_desc_gk0_gn0_gn1_gk1),
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
c_grid_desc_gm0_gm1_gn0_gn1); c_grid_desc_gm0_gm1_gn0_gn1),
auto c_grid_block_cluster_blockid_to_gm10_gn10 =
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1); c_grid_desc_gm0_gm1_gn0_gn1));
if(hipThreadIdx_x == 0) *static_cast<decltype(desc_tuple)*>(p_desc_tuple) = desc_tuple;
{ }
*static_cast<decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1)*>(
p_a_grid_desc_gk0_gm0_gm10_gm11_gk1) = a_grid_desc_gk0_gm0_gm10_gm11_gk1;
*static_cast<decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1)*>(
p_b_grid_desc_gk0_gn0_gn10_gn11_gk1) = b_grid_desc_gk0_gn0_gn10_gn11_gk1;
*static_cast<decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1)*>(
p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1) = c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1;
*static_cast<decltype(c_grid_block_cluster_blockid_to_gm10_gn10)*>(
p_c_grid_block_cluster_blockid_to_gm10_gn10) =
c_grid_block_cluster_blockid_to_gm10_gn10;
};
}; };
extern "C" __global__ void extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1, const void CONSTANT* p_desc_tuple)
const void __CONSTANT__* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1,
const void __CONSTANT__* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
const void __CONSTANT__* p_c_grid_block_cluster_blockid_to_gm10_gn10)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_desc = constexpr auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
...@@ -316,12 +303,12 @@ extern "C" __global__ void ...@@ -316,12 +303,12 @@ extern "C" __global__ void
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
using GridwiseContraction = using GridwiseContraction =
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
AGridDesc_GK0_GM0_GM1_GK1, AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1, BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1, CGridDesc_GM0_GM1_GN0_GN1,
...@@ -331,10 +318,8 @@ extern "C" __global__ void ...@@ -331,10 +318,8 @@ extern "C" __global__ void
BM1PerThreadBM11, BM1PerThreadBM11,
BN1PerThreadBN11, BN1PerThreadBN11,
BK0PerThread, BK0PerThread,
BM10BN10ThreadClusterBM100, BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN100, BM10BN10ThreadClusterBN10Xs,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
...@@ -371,18 +356,23 @@ extern "C" __global__ void ...@@ -371,18 +356,23 @@ extern "C" __global__ void
decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1)); c_grid_desc_gm0_gm1_gn0_gn1));
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = using DescTuple = decltype(make_tuple(AGridDesc_GK0_GM0_GM10_GM11_GK1{},
*reinterpret_cast<const AGridDesc_GK0_GM0_GM10_GM11_GK1*>( BGridDesc_GK0_GN0_GN10_GN11_GK1{},
(const void*)p_a_grid_desc_gk0_gm0_gm10_gm11_gk1); CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{},
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = CGridBlockCluster_BlockId_To_GM10_GN10{}));
*reinterpret_cast<const BGridDesc_GK0_GN0_GN10_GN11_GK1*>(
(const void*)p_b_grid_desc_gk0_gn0_gn10_gn11_gk1); const auto desc_tuple = *reinterpret_cast<const DescTuple*>(
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = #pragma clang diagnostic push
*reinterpret_cast<const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1*>( #pragma clang diagnostic ignored "-Wold-style-cast"
(const void*)p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1); // TODO: how to cast?
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = (const void*)p_desc_tuple
*reinterpret_cast<const CGridBlockCluster_BlockId_To_GM10_GN10*>( #pragma clang diagnostic pop
(const void*)p_c_grid_block_cluster_blockid_to_gm10_gn10); );
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0];
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1];
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = desc_tuple[I2];
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = desc_tuple[I3];
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......
...@@ -12,17 +12,17 @@ ...@@ -12,17 +12,17 @@
#include "conv_common.hpp" #include "conv_common.hpp"
#include "host_conv.hpp" #include "host_conv.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1 #define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 1 #define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V6R1_NCHW 1
#define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 #define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
...@@ -301,9 +301,10 @@ int main(int argc, char* argv[]) ...@@ -301,9 +301,10 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<in_data_t,
acc_data_t, acc_data_t,
out_data_t>(tmp[I0], out_data_t>(
tmp[I0],
tmp[I1], tmp[I1],
tmp[I2], tmp[I2],
tmp[I3], tmp[I3],
...@@ -327,7 +328,7 @@ int main(int argc, char* argv[]) ...@@ -327,7 +328,7 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nhwc(); const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk<in_data_t,
acc_data_t, acc_data_t,
out_data_t>( out_data_t>(
tmp[I0], tmp[I0],
...@@ -354,9 +355,10 @@ int main(int argc, char* argv[]) ...@@ -354,9 +355,10 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<in_data_t,
acc_data_t, acc_data_t,
out_data_t>(tmp[I0], out_data_t>(
tmp[I0],
tmp[I1], tmp[I1],
tmp[I2], tmp[I2],
tmp[I3], tmp[I3],
...@@ -380,10 +382,11 @@ int main(int argc, char* argv[]) ...@@ -380,10 +382,11 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t,
16, 16,
acc_data_t, acc_data_t,
out_data_t>(tmp[I0], out_data_t>(
tmp[I0],
tmp[I1], tmp[I1],
tmp[I2], tmp[I2],
tmp[I3], tmp[I3],
......
...@@ -264,7 +264,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx ...@@ -264,7 +264,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc), decltype(in_gemmm_gemmn_grid_desc),
......
...@@ -236,7 +236,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k ...@@ -236,7 +236,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc), decltype(in_gemmm_gemmn_grid_desc),
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_gemm_v1r2.hpp" #include "driver_dynamic_gemm_dlops_v1r2.hpp"
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
...@@ -14,7 +14,7 @@ template <typename TInWei, ...@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths, const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths, const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths, const OutLengths& out_n_k_ho_wo_lengths,
...@@ -142,12 +142,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -142,12 +142,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_dynamic_gemm_v1r2< float ave_time = driver_dynamic_gemm_dlops_v1r2<
BlockSize, BlockSize,
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk_gemmm_grid_desc), decltype(wei_gemmk_gemmm_grid_desc),
decltype(in_gemmk_gemmn_grid_desc), decltype(in_gemmk_gemmn_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
...@@ -220,7 +220,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -220,7 +220,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(descs[I0]), decltype(descs[I0]),
decltype(descs[I1]), decltype(descs[I1]),
decltype(descs[I2]), decltype(descs[I2]),
......
...@@ -145,7 +145,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk ...@@ -145,7 +145,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
...@@ -165,7 +165,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh ...@@ -165,7 +165,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
...@@ -229,7 +229,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh ...@@ -229,7 +229,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
...@@ -288,7 +288,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -288,7 +288,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" #include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp" #include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
template <typename TInWei, template <typename TInWei,
ck::index_t InWeiVectorSize, ck::index_t InWeiVectorSize,
...@@ -15,7 +15,7 @@ template <typename TInWei, ...@@ -15,7 +15,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths, const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths, const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths, const OutLengths& out_n_k_ho_wo_lengths,
...@@ -145,9 +145,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -145,9 +145,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr auto conv_driver = constexpr auto conv_driver =
#if 0 #if 0
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
#else #else
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
#endif #endif
<BlockSize, <BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type, typename vector_type<TInWei, InWeiVectorSize>::type,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment