Commit 8bfacf9f authored by Jing Zhang's avatar Jing Zhang Committed by root
Browse files

move block2tile into kernel

parent cf9bcb31
...@@ -60,9 +60,6 @@ __global__ void ...@@ -60,9 +60,6 @@ __global__ void
// const auto K0 = gemm_shared_args.KPadded; // const auto K0 = gemm_shared_args.KPadded;
// const auto k_batch = gemm_shared_args.k_batch; // const auto k_batch = gemm_shared_args.k_batch;
// M = 2 N = 768 K = 4608 StrideA = 4608 StrideB = 4608 StrideC = 768 MPadded = 32 NPadded = 768
// KPadded = 4608 K0 = 576 k_batch = 1
const auto M = 2; const auto M = 2;
const auto N = 768; const auto N = 768;
const auto K = 4608; const auto K = 4608;
...@@ -75,7 +72,22 @@ __global__ void ...@@ -75,7 +72,22 @@ __global__ void
const auto K0 = 576; const auto K0 = 576;
const auto k_batch = 1; const auto k_batch = 1;
// const auto block_2_ctile_map = gemm_shared_args.block_2_ctile_map; static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
static constexpr index_t B2E_M01 = 8;
const index_t block_start = gemm_shared_args.block_size * group_id;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
auto grouped_block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
const auto block_2_ctile_map = grouped_block_2_ctile_map;
#endif #endif
...@@ -95,7 +107,7 @@ __global__ void ...@@ -95,7 +107,7 @@ __global__ void
K0, K0,
k_batch, k_batch,
static_cast<void*>(p_shared), static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].karg_.block_2_ctile_map); block_2_ctile_map);
#else #else
ignore = gemm_descs_const; ignore = gemm_descs_const;
ignore = all_gemm_block_size; ignore = all_gemm_block_size;
...@@ -533,7 +545,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -533,7 +545,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
// index_t StrideA; // index_t StrideA;
// index_t StrideC; // index_t StrideC;
// index_t MPadded; // index_t MPadded;
GroupedGemmBlock2ETileMap block_2_ctile_map; // GroupedGemmBlock2ETileMap block_2_ctile_map;
}; };
struct GemmTransKernelArgMsN1K1 struct GemmTransKernelArgMsN1K1
...@@ -549,10 +561,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -549,10 +561,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
for(const auto& trans_arg : arg.gemm_kernel_args_) for(const auto& trans_arg : arg.gemm_kernel_args_)
{ {
auto karg = ArgumentMsN1K1{trans_arg.karg_.p_a_grid, auto karg = ArgumentMsN1K1{
trans_arg.karg_.p_b_grid, trans_arg.karg_.p_a_grid, trans_arg.karg_.p_b_grid, trans_arg.karg_.p_c_grid};
trans_arg.karg_.p_c_grid,
trans_arg.block_2_ctile_map_};
// auto block_size = trans_arg.block_end_ - trans_arg.block_start_; // auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
// std::cout << "trans_arg.block_start_: " << trans_arg.block_start_ // std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
......
...@@ -551,7 +551,8 @@ struct OffsettedBlockToCTileMap ...@@ -551,7 +551,8 @@ struct OffsettedBlockToCTileMap
{ {
using underlying_type = UnderlyingBlockToCTileMap; using underlying_type = UnderlyingBlockToCTileMap;
OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start) __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t block_start)
{ {
block_to_ctile_map_ = block_to_ctile_map; block_to_ctile_map_ = block_to_ctile_map;
block_start_ = block_start; block_start_ = block_start;
......
...@@ -1090,6 +1090,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -1090,6 +1090,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
block_2_ctile_map); block_2_ctile_map);
} }
static constexpr auto GetMPerBlock() { return MPerBlock; }
static constexpr auto GetNPerBlock() { return NPerBlock; }
static std::string GetTypeString() static std::string GetTypeString()
{ {
auto str = std::stringstream(); auto str = std::stringstream();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment