Commit 5b242405 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent f1403dac
......@@ -17,8 +17,8 @@ template <index_t BlockSize,
index_t WoPerBlock,
index_t EPerBlock,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
......@@ -178,8 +178,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
WoPerBlock,
EPerBlock,
KPerThread,
HPerThread,
WPerThread,
HoPerThread,
WoPerThread,
EPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
......
......@@ -3,10 +3,10 @@
template <typename GridwiseOp, typename... Xs>
__global__ void
#if 0
__launch_bounds__(256, 2)
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
run_gridwise_operation(Xs... xs)
run_gridwise_operation(Xs... xs)
{
GridwiseOp{}.Run(xs...);
}
......
......@@ -7,6 +7,10 @@
#endif
#include "bfloat16_dev.hpp"
// device backend
#define CK_DEVICE_BACKEND_AMD 1
// GPU ID
#if 1
#define CK_AMD_GPU_GFX906 1
#elif 0
......@@ -15,22 +19,29 @@
#define CK_AMD_GPU_GFX1030 1
#endif
// HIP version
#ifndef CK_HIP_VERSION_FLAT
#define CK_HIP_VERSION_FLAT 0
#endif
// launch bounds
#define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 1
#endif
// buffer resourse
#if defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(CK_AMD_GPU_GFX1030)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#endif
#ifndef CK_HIP_VERSION_FLAT
#define CK_HIP_VERSION_FLAT 0
#endif
// multi index
#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
// device backend
#define CK_DEVICE_BACKEND_AMD 1
// AMD inline asm
#ifndef CK_USE_AMD_INLINE_ASM
#define CK_USE_AMD_INLINE_ASM 1
......
......@@ -133,6 +133,39 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 1
// cdata = 64, BlockSize 64, 16x256x2
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 1
// cdata = 64, BlockSize 64, 16x256x4
......
......@@ -70,15 +70,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 256;
constexpr index_t KPerBlock = 16;
constexpr index_t HPerBlock = 16;
constexpr index_t WPerBlock = 16;
constexpr index_t CYXPerBlock = 4;
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 16;
constexpr index_t WoPerBlock = 16;
constexpr index_t EPerBlock = 4;
constexpr index_t KPerThread = 4;
constexpr index_t HPerThread = 2;
constexpr index_t WPerThread = 2;
constexpr index_t CYXPerThread = 4;
constexpr index_t KPerThread = 4;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
......@@ -97,13 +97,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
TDevice,
TDevice,
KPerBlock,
HPerBlock,
WPerBlock,
CYXPerBlock,
HoPerBlock,
WoPerBlock,
EPerBlock,
KPerThread,
HPerThread,
WPerThread,
CYXPerThread,
HoPerThread,
WoPerThread,
EPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
GemmABlockTransferSrcScalarPerVector_GemmK,
......
......@@ -34,8 +34,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
......@@ -736,7 +736,7 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif 1
#elif 0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size,
acc_data_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment