Commit 5b242405 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent f1403dac
...@@ -17,8 +17,8 @@ template <index_t BlockSize, ...@@ -17,8 +17,8 @@ template <index_t BlockSize,
index_t WoPerBlock, index_t WoPerBlock,
index_t EPerBlock, index_t EPerBlock,
index_t KPerThread, index_t KPerThread,
index_t HPerThread, index_t HoPerThread,
index_t WPerThread, index_t WoPerThread,
index_t EPerThread, index_t EPerThread,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM, typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM, typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
...@@ -178,8 +178,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -178,8 +178,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
WoPerBlock, WoPerBlock,
EPerBlock, EPerBlock,
KPerThread, KPerThread,
HPerThread, HoPerThread,
WPerThread, WoPerThread,
EPerThread, EPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
......
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
template <typename GridwiseOp, typename... Xs> template <typename GridwiseOp, typename... Xs>
__global__ void __global__ void
#if 0 #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(256, 2) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
run_gridwise_operation(Xs... xs) run_gridwise_operation(Xs... xs)
{ {
GridwiseOp{}.Run(xs...); GridwiseOp{}.Run(xs...);
} }
......
...@@ -7,6 +7,10 @@ ...@@ -7,6 +7,10 @@
#endif #endif
#include "bfloat16_dev.hpp" #include "bfloat16_dev.hpp"
// device backend
#define CK_DEVICE_BACKEND_AMD 1
// GPU ID
#if 1 #if 1
#define CK_AMD_GPU_GFX906 1 #define CK_AMD_GPU_GFX906 1
#elif 0 #elif 0
...@@ -15,22 +19,29 @@ ...@@ -15,22 +19,29 @@
#define CK_AMD_GPU_GFX1030 1 #define CK_AMD_GPU_GFX1030 1
#endif #endif
// HIP version
#ifndef CK_HIP_VERSION_FLAT
#define CK_HIP_VERSION_FLAT 0
#endif
// launch bounds
#define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 1
#endif
// buffer resourse
#if defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) #if defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(CK_AMD_GPU_GFX1030) #elif defined(CK_AMD_GPU_GFX1030)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#endif #endif
#ifndef CK_HIP_VERSION_FLAT
#define CK_HIP_VERSION_FLAT 0
#endif
// multi index // multi index
#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0 #define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
// device backend
#define CK_DEVICE_BACKEND_AMD 1
// AMD inline asm // AMD inline asm
#ifndef CK_USE_AMD_INLINE_ASM #ifndef CK_USE_AMD_INLINE_ASM
#define CK_USE_AMD_INLINE_ASM 1 #define CK_USE_AMD_INLINE_ASM 1
......
...@@ -133,6 +133,39 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -133,6 +133,39 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 1
// cdata = 64, BlockSize 64, 16x256x2
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 1 #elif 1
// cdata = 64, BlockSize 64, 16x256x4 // cdata = 64, BlockSize 64, 16x256x4
......
...@@ -70,15 +70,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -70,15 +70,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t HPerBlock = 16; constexpr index_t HoPerBlock = 16;
constexpr index_t WPerBlock = 16; constexpr index_t WoPerBlock = 16;
constexpr index_t CYXPerBlock = 4; constexpr index_t EPerBlock = 4;
constexpr index_t KPerThread = 4; constexpr index_t KPerThread = 4;
constexpr index_t HPerThread = 2; constexpr index_t HoPerThread = 2;
constexpr index_t WPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t CYXPerThread = 4; constexpr index_t EPerThread = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
...@@ -97,13 +97,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -97,13 +97,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
TDevice, TDevice,
TDevice, TDevice,
KPerBlock, KPerBlock,
HPerBlock, HoPerBlock,
WPerBlock, WoPerBlock,
CYXPerBlock, EPerBlock,
KPerThread, KPerThread,
HPerThread, HoPerThread,
WPerThread, WoPerThread,
CYXPerThread, EPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
GemmABlockTransferSrcScalarPerVector_GemmK, GemmABlockTransferSrcScalarPerVector_GemmK,
......
...@@ -34,8 +34,8 @@ int main(int argc, char* argv[]) ...@@ -34,8 +34,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
...@@ -736,7 +736,7 @@ int main(int argc, char* argv[]) ...@@ -736,7 +736,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size, in_vector_size,
acc_data_t, acc_data_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment