Commit ea3ff2b7 authored by Adam Osewski's avatar Adam Osewski
Browse files

Use LocalBlockToCTile map in device ops.

parent 5ba70c28
...@@ -162,7 +162,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -162,7 +162,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType))); hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map); stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}; };
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
......
...@@ -61,10 +61,12 @@ __global__ void ...@@ -61,10 +61,12 @@ __global__ void
group_id = index_t((left + right) / 2); group_id = index_t((left + right) / 2);
} }
LocalBlockToCTileMap<typename GemmDesc::B2CType> local_b2c{
gemm_desc_ptr[group_id].block_2_ctile_map_,
block_id - gemm_desc_ptr[group_id].block_start_};
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_, gemm_desc_ptr[group_id].karg_, static_cast<void*>(p_shared), local_b2c);
static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].block_2_ctile_map_);
#else #else
ignore = gemm_descs_const; ignore = gemm_descs_const;
ignore = group_count; ignore = group_count;
...@@ -189,18 +191,20 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -189,18 +191,20 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>; BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
// Block2CTileMap configuration parameter. // Block2CTileMap configuration parameter.
static constexpr index_t B2E_M01 = 8; static constexpr index_t B2E_M01 = 8;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>; // using GroupedGemmBlock2ETileMap = LocalBlockToCTileMap<Block2ETileMapKSplit>;
using KernelArgument = typename GridwiseGemm::Argument; using KernelArgument = typename GridwiseGemm::Argument;
struct GemmTransKernelArg struct GemmTransKernelArg
{ {
using B2CType = Block2ETileMapKSplit;
KernelArgument karg_; KernelArgument karg_;
GroupedGemmBlock2ETileMap block_2_ctile_map_; Block2ETileMapKSplit block_2_ctile_map_;
index_t block_start_, block_end_; index_t block_start_, block_end_;
GemmTransKernelArg() = default; GemmTransKernelArg() = default;
GemmTransKernelArg(KernelArgument&& karg, GemmTransKernelArg(KernelArgument&& karg,
GroupedGemmBlock2ETileMap&& b2c_map, Block2ETileMapKSplit&& b2c_map,
index_t block_start, index_t block_start,
index_t block_end) index_t block_end)
: karg_{karg}, : karg_{karg},
...@@ -270,8 +274,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -270,8 +274,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c); const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
const auto local_b2c_tile_map = auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t block_start = grid_size_; const index_t block_start = grid_size_;
...@@ -279,10 +282,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -279,10 +282,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
auto karg = KernelArgument{type_convert<const ADataType*>(p_As[i]), auto karg = KernelArgument{type_convert<const ADataType*>(p_As[i]),
type_convert<const BDataType*>(p_Bs[i]), type_convert<const BDataType*>(p_Bs[i]),
type_convert<EDataType*>(p_Es[i]), type_convert<EDataType*>(p_Es[i]),
...@@ -299,7 +298,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -299,7 +298,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
K_BATCH}; K_BATCH};
gemm_kernel_args_.emplace_back( gemm_kernel_args_.emplace_back(
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); std::move(karg), std::move(local_b2c_tile_map), block_start, block_end);
} }
} }
...@@ -324,8 +323,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -324,8 +323,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const auto c_grid_desc_m_n = const auto c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto local_b2c_tile_map = auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t block_start = grid_size_; const index_t block_start = grid_size_;
...@@ -333,14 +331,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -333,14 +331,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg.KPadded = k_padded; karg.KPadded = k_padded;
karg.K0 = k0; karg.K0 = k0;
karg.k_batch = K_BATCH; karg.k_batch = K_BATCH;
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; gemm_kernel_args_[i].block_2_ctile_map_ = local_b2c_tile_map;
gemm_kernel_args_[i].block_start_ = block_start; gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end; gemm_kernel_args_[i].block_end_ = block_end;
} }
......
...@@ -27,8 +27,7 @@ __global__ void ...@@ -27,8 +27,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg)
const Block2CTileMap& b2c_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -36,11 +35,12 @@ __global__ void ...@@ -36,11 +35,12 @@ __global__ void
__shared__ uint8_t p_shared[shared_size]; __shared__ uint8_t p_shared[shared_size];
Block2CTileMap b2c_map{get_block_1d_id()};
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
karg, static_cast<void*>(p_shared), b2c_map); karg, static_cast<void*>(p_shared), b2c_map);
#else #else
ignore = karg; ignore = karg;
ignore = b2c_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -601,7 +601,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -601,7 +601,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// divide block work by [KBatch, M, N] // divide block work by [KBatch, M, N]
const auto block_work_idx = const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); // block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
block_2_ctile_map.CalculateBottomIndex();
if(!block_2_ctile_map.ValidCTileIndex( if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx, block_work_idx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment