Commit ea3ff2b7 authored by Adam Osewski's avatar Adam Osewski
Browse files

Use LocalBlockToCTile map in device ops.

parent 5ba70c28
......@@ -162,7 +162,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
};
if(has_main_k0_block_loop)
......
......@@ -61,10 +61,12 @@ __global__ void
group_id = index_t((left + right) / 2);
}
LocalBlockToCTileMap<typename GemmDesc::B2CType> local_b2c{
gemm_desc_ptr[group_id].block_2_ctile_map_,
block_id - gemm_desc_ptr[group_id].block_start_};
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_,
static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].block_2_ctile_map_);
gemm_desc_ptr[group_id].karg_, static_cast<void*>(p_shared), local_b2c);
#else
ignore = gemm_descs_const;
ignore = group_count;
......@@ -189,18 +191,20 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
// Block2CTileMap configuration parameter.
static constexpr index_t B2E_M01 = 8;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
using KernelArgument = typename GridwiseGemm::Argument;
// using GroupedGemmBlock2ETileMap = LocalBlockToCTileMap<Block2ETileMapKSplit>;
using KernelArgument = typename GridwiseGemm::Argument;
struct GemmTransKernelArg
{
using B2CType = Block2ETileMapKSplit;
KernelArgument karg_;
GroupedGemmBlock2ETileMap block_2_ctile_map_;
Block2ETileMapKSplit block_2_ctile_map_;
index_t block_start_, block_end_;
GemmTransKernelArg() = default;
GemmTransKernelArg(KernelArgument&& karg,
GroupedGemmBlock2ETileMap&& b2c_map,
Block2ETileMapKSplit&& b2c_map,
index_t block_start,
index_t block_end)
: karg_{karg},
......@@ -270,8 +274,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
const auto local_b2c_tile_map =
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t block_start = grid_size_;
......@@ -279,10 +282,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
grid_size_ += grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
auto karg = KernelArgument{type_convert<const ADataType*>(p_As[i]),
type_convert<const BDataType*>(p_Bs[i]),
type_convert<EDataType*>(p_Es[i]),
......@@ -299,7 +298,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
K_BATCH};
gemm_kernel_args_.emplace_back(
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
std::move(karg), std::move(local_b2c_tile_map), block_start, block_end);
}
}
......@@ -324,8 +323,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const auto c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto local_b2c_tile_map =
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t block_start = grid_size_;
......@@ -333,14 +331,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
grid_size_ += grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg.KPadded = k_padded;
karg.K0 = k0;
karg.k_batch = K_BATCH;
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_2_ctile_map_ = local_b2c_tile_map;
gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end;
}
......
......@@ -27,8 +27,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
const Block2CTileMap& b2c_map)
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......@@ -36,11 +35,12 @@ __global__ void
__shared__ uint8_t p_shared[shared_size];
Block2CTileMap b2c_map{get_block_1d_id()};
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
karg, static_cast<void*>(p_shared), b2c_map);
#else
ignore = karg;
ignore = b2c_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -601,7 +601,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// divide block work by [KBatch, M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
block_2_ctile_map.CalculateBottomIndex();
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment