Commit 16467e0e authored by wangshaojie6's avatar wangshaojie6
Browse files

run gemm instance without code modification

parent 16f47b25
This diff is collapsed.
......@@ -224,7 +224,7 @@ struct DeviceGemmXdlSplitKCShuffleSmallGemm
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<1, 1, 1, 1,
BlockSize,
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
......@@ -268,7 +268,7 @@ struct DeviceGemmXdlSplitKCShuffleSmallGemm
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
// GridwiseGemm
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<1, 1, 1, 1,
BlockSize,
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
......
......@@ -9,7 +9,7 @@
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
#include "gridwise_gemm_xdlops_v2r4r2_static.hpp"
#include "gemm_specialization.hpp"
#ifndef CK_RUN_KERNEL_AND_TIME
......@@ -20,7 +20,11 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
template <index_t M_matrix,
index_t N_matrix,
index_t K_matrix,
index_t K_batch,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
......@@ -246,7 +250,11 @@ struct DeviceGemmXdlSplitKCShuffleStatic
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N());
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2_static<
M_matrix,
N_matrix,
K_matrix,
K_batch,
BlockSize,
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
......@@ -290,7 +298,11 @@ struct DeviceGemmXdlSplitKCShuffleStatic
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
// GridwiseGemm
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2_static<
M_matrix,
N_matrix,
K_matrix,
K_batch,
BlockSize,
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
......@@ -444,7 +456,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting");
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2_static has invalid setting");
}
const index_t grid_size =
......@@ -490,13 +502,35 @@ struct DeviceGemmXdlSplitKCShuffleStatic
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
// check validaty when using splitk
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
launch_and_time_kernel({stream_config.stream_id_, false},
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop)
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
const auto kernel = kernel_gemm_xdlops_v2r4r2_static<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......@@ -514,7 +548,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
const auto kernel = kernel_gemm_xdlops_v2r4r2_static<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......@@ -535,7 +569,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
const auto kernel = kernel_gemm_xdlops_v2r4r2_static<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......@@ -553,7 +587,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
const auto kernel = kernel_gemm_xdlops_v2r4r2_static<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......
......@@ -241,7 +241,7 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N, index_t K_batch>
struct BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static
{
static constexpr auto I0 = Number<0>{};
......
......@@ -77,7 +77,11 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <index_t BlockSize,
template <index_t M_matrix,
index_t N_matrix,
index_t K_matrix,
index_t K_batch,
index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
......@@ -289,7 +293,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptorStatic(
const CMNGridDesc& c_m_n_grid_desc)
{
return BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static<MPerBlock, NPerBlock, CMNGridDesc>(
return BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static<MPerBlock, NPerBlock, CMNGridDesc, K_batch>(
c_m_n_grid_desc);
}
......
......@@ -50,7 +50,7 @@
#define USEING_STATIC_KERNEL 1
#define MNKB_0_8 1
#define MNKB_0_8 0
#define MNKB_1_4 0
#define MNKB_2_8 0
#define MNKB_3_5 0
......@@ -60,23 +60,23 @@
#if MNKB_0_8
#define M_matrix 16
#define N_matrix 4096
#define K_matrix 12800
#define K_batch 5
#define N_matrix 1152
#define K_matrix 5120
#define K_batch 8
#elif MNKB_1_4
#define M_matrix 16
#define N_matrix 4096
#define K_matrix 12800
#define K_batch 5
#define N_matrix 5120
#define K_matrix 384
#define K_batch 4
#elif MNKB_2_8
#define M_matrix 16
#define N_matrix 4096
#define K_matrix 12800
#define K_batch 5
#define N_matrix 1280
#define K_matrix 5120
#define K_batch 8
#elif MNKB_3_5
#define M_matrix 16
#define N_matrix 4096
#define K_matrix 12800
#define N_matrix 5120
#define K_matrix 1280
#define K_batch 5
#elif MNKB_4_5
#define M_matrix 16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment