Commit 16467e0e authored by wangshaojie6's avatar wangshaojie6
Browse files

run gemm instance without code modification

parent 16f47b25
This diff is collapsed.
...@@ -224,7 +224,7 @@ struct DeviceGemmXdlSplitKCShuffleSmallGemm ...@@ -224,7 +224,7 @@ struct DeviceGemmXdlSplitKCShuffleSmallGemm
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<1, 1, 1, 1,
BlockSize, BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
...@@ -268,7 +268,7 @@ struct DeviceGemmXdlSplitKCShuffleSmallGemm ...@@ -268,7 +268,7 @@ struct DeviceGemmXdlSplitKCShuffleSmallGemm
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<1, 1, 1, 1,
BlockSize, BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4r2.hpp" #include "gridwise_gemm_xdlops_v2r4r2_static.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
#ifndef CK_RUN_KERNEL_AND_TIME #ifndef CK_RUN_KERNEL_AND_TIME
...@@ -20,7 +20,11 @@ namespace ck { ...@@ -20,7 +20,11 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename ALayout, template <index_t M_matrix,
index_t N_matrix,
index_t K_matrix,
index_t K_batch,
typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
typename ADataType, typename ADataType,
...@@ -246,7 +250,11 @@ struct DeviceGemmXdlSplitKCShuffleStatic ...@@ -246,7 +250,11 @@ struct DeviceGemmXdlSplitKCShuffleStatic
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N()); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N());
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2_static<
M_matrix,
N_matrix,
K_matrix,
K_batch,
BlockSize, BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
...@@ -290,7 +298,11 @@ struct DeviceGemmXdlSplitKCShuffleStatic ...@@ -290,7 +298,11 @@ struct DeviceGemmXdlSplitKCShuffleStatic
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2_static<
M_matrix,
N_matrix,
K_matrix,
K_batch,
BlockSize, BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
...@@ -444,7 +456,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic ...@@ -444,7 +456,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
arg.block_2_ctile_map_)) arg.block_2_ctile_map_))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting"); "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2_static has invalid setting");
} }
const index_t grid_size = const index_t grid_size =
...@@ -490,13 +502,35 @@ struct DeviceGemmXdlSplitKCShuffleStatic ...@@ -490,13 +502,35 @@ struct DeviceGemmXdlSplitKCShuffleStatic
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// check validaty when using splitk
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
launch_and_time_kernel({stream_config.stream_id_, false},
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}; };
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
{ {
if(kbatch == 1) if(kbatch == 1)
{ {
const auto kernel = kernel_gemm_xdlops_v2r4r2< const auto kernel = kernel_gemm_xdlops_v2r4r2_static<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
...@@ -514,7 +548,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic ...@@ -514,7 +548,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r4r2< const auto kernel = kernel_gemm_xdlops_v2r4r2_static<
GridwiseGemmAtomicAdd, GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
...@@ -535,7 +569,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic ...@@ -535,7 +569,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
{ {
if(kbatch == 1) if(kbatch == 1)
{ {
const auto kernel = kernel_gemm_xdlops_v2r4r2< const auto kernel = kernel_gemm_xdlops_v2r4r2_static<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
...@@ -553,7 +587,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic ...@@ -553,7 +587,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r4r2< const auto kernel = kernel_gemm_xdlops_v2r4r2_static<
GridwiseGemmAtomicAdd, GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
......
...@@ -241,7 +241,7 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt ...@@ -241,7 +241,7 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
// 2D slices of column-vectors in 3D space // 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range // This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N> template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N, index_t K_batch>
struct BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static struct BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
......
...@@ -77,7 +77,11 @@ __global__ void ...@@ -77,7 +77,11 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
template <index_t BlockSize, template <index_t M_matrix,
index_t N_matrix,
index_t K_matrix,
index_t K_batch,
index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
...@@ -289,7 +293,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -289,7 +293,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptorStatic( __host__ __device__ static constexpr auto MakeCBlockClusterAdaptorStatic(
const CMNGridDesc& c_m_n_grid_desc) const CMNGridDesc& c_m_n_grid_desc)
{ {
return BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static<MPerBlock, NPerBlock, CMNGridDesc>( return BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static<MPerBlock, NPerBlock, CMNGridDesc, K_batch>(
c_m_n_grid_desc); c_m_n_grid_desc);
} }
......
...@@ -50,7 +50,7 @@ ...@@ -50,7 +50,7 @@
#define USEING_STATIC_KERNEL 1 #define USEING_STATIC_KERNEL 1
#define MNKB_0_8 1 #define MNKB_0_8 0
#define MNKB_1_4 0 #define MNKB_1_4 0
#define MNKB_2_8 0 #define MNKB_2_8 0
#define MNKB_3_5 0 #define MNKB_3_5 0
...@@ -60,23 +60,23 @@ ...@@ -60,23 +60,23 @@
#if MNKB_0_8 #if MNKB_0_8
#define M_matrix 16 #define M_matrix 16
#define N_matrix 4096 #define N_matrix 1152
#define K_matrix 12800 #define K_matrix 5120
#define K_batch 5 #define K_batch 8
#elif MNKB_1_4 #elif MNKB_1_4
#define M_matrix 16 #define M_matrix 16
#define N_matrix 4096 #define N_matrix 5120
#define K_matrix 12800 #define K_matrix 384
#define K_batch 5 #define K_batch 4
#elif MNKB_2_8 #elif MNKB_2_8
#define M_matrix 16 #define M_matrix 16
#define N_matrix 4096 #define N_matrix 1280
#define K_matrix 12800 #define K_matrix 5120
#define K_batch 5 #define K_batch 8
#elif MNKB_3_5 #elif MNKB_3_5
#define M_matrix 16 #define M_matrix 16
#define N_matrix 4096 #define N_matrix 5120
#define K_matrix 12800 #define K_matrix 1280
#define K_batch 5 #define K_batch 5
#elif MNKB_4_5 #elif MNKB_4_5
#define M_matrix 16 #define M_matrix 16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment