Commit fee2002c authored by Jing Zhang

add b2c_tile_map

parent 326d6bc6
@@ -20,6 +20,24 @@ struct GemmDesc
     std::vector<ck::index_t> stride_Ds_;
 };
 
+template <index_t NumDTensor = 0>
+struct GroupedGemmKernelArgument
+{
+    const void* p_a_grid;
+    const void* p_b_grid;
+    std::array<const void*, NumDTensor> p_ds_grid;
+    void* p_e_grid;
+
+    index_t M;
+    index_t N;
+    index_t K;
+
+    index_t StrideA;
+    index_t StrideB;
+    std::array<index_t, NumDTensor> StrideDs;
+    index_t StrideE;
+};
+
 template <typename ALayout,
           typename BLayout,
           typename DsLayout,
...
@@ -25,7 +25,12 @@ namespace device {
 template <typename GridwiseGemm,
           typename GemmDesc,
           GemmSpecialization GemmSpec,
-          typename Block2ETileMapKSplit,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename ELayout,
+          typename Block2ETileMap,
+          typename GroupedGemmBlock2ETileMap,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
@@ -98,11 +103,12 @@ __global__ void
     if(M == 0 || N == 0 || K == 0)
         return;
 
-    const index_t StrideA    = K;
-    const index_t StrideB    = K;
-    const index_t StrideDs[] = {};
-    const index_t StrideE    = N;
+    const auto StrideA  = gemm_desc_ptr[group_id].StrideA;
+    const auto StrideB  = gemm_desc_ptr[group_id].StrideB;
+    const auto StrideDs = gemm_desc_ptr[group_id].StrideDs;
+    const auto StrideE  = gemm_desc_ptr[group_id].StrideE;
 
+#if 0
     using Row = ck::tensor_layout::gemm::RowMajor;
     using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -110,34 +116,60 @@ __global__ void
     using BLayout  = Col;
     using DsLayout = ck::Tuple<>;
     using ELayout  = Row;
+#endif
+
+    using DsDataType = ck::Tuple<>;
 
     const auto e_grid_desc_m_n =
         GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
 
     const index_t BlockStart = group_id * grid_size_grp;
 
-    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
-    const auto local_b2e_tile_map = Block2ETileMapKSplit{e_grid_desc_m_n};
-    const auto block_2_etile_map  = GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart);
-
-    GridwiseGemm::template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
-        gemm_desc_ptr[group_id].a_ptr_,
-        gemm_desc_ptr[group_id].b_ptr_,
-        gemm_desc_ptr[group_id].ds_ptr_,
-        gemm_desc_ptr[group_id].e_ptr_,
-        p_shared,
-        a_element_op,
-        b_element_op,
-        c_element_op,
-        M,
-        N,
-        K,
-        StrideA,
-        StrideB,
-        StrideDs,
-        StrideE,
-        block_2_etile_map);
+    const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n};
+
+    constexpr auto NumDTensor = 0;
+
+    using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
+    DsGridPointer p_ds_grid_;
+
+    static_for<0, NumDTensor, 1>{}([&](auto i) {
+        using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+
+        // D pointer
+        p_ds_grid_(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
+    });
+
+    auto m_loops = local_b2e_tile_map.CalculateMLoops();
+    index_t m_id = 0;
+
+    do
+    {
+        const auto block_2_etile_map =
+            GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, m_id);
+
+        GridwiseGemm::
+            template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
+                gemm_desc_ptr[group_id].p_a_grid,
+                gemm_desc_ptr[group_id].p_b_grid,
+                p_ds_grid_,
+                gemm_desc_ptr[group_id].p_e_grid,
+                p_shared,
+                a_element_op,
+                b_element_op,
+                c_element_op,
+                M,
+                N,
+                K,
+                StrideA,
+                StrideB,
+                StrideDs,
+                StrideE,
+                block_2_etile_map);
+
+        m_id += 1;
+    } while(m_id < m_loops);
 #endif
 #else
@@ -342,18 +374,162 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
     using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
 
-    using Block2ETileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>;
-    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
+    template <typename UnderlyingBlockToCTileMap>
+    struct OffsettedBlockToCTileMapMLoops
+    {
+        using underlying_type = UnderlyingBlockToCTileMap;
+
+        __host__ __device__
+        OffsettedBlockToCTileMapMLoops(UnderlyingBlockToCTileMap block_to_ctile_map,
+                                       index_t block_start,
+                                       index_t mblock_id_off = 0)
+        {
+            block_to_ctile_map_ = block_to_ctile_map;
+            block_start_        = block_start;
+            mblock_id_off_      = mblock_id_off;
+        }
+
+        template <typename TopIdx>
+        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+        {
+            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
+                make_multi_index(idx_top[Number<0>{}] - block_start_));
+
+            return make_tuple(idx_bot[Number<0>{}] + mblock_id_off_, idx_bot[Number<1>{}]);
+        }
+
+        template <typename CTileIdx, typename CTileDim>
+        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                                 const CTileDim& c_tile_dim) const
+        {
+            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+        {
+            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+        {
+            return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
+        }
+
+        UnderlyingBlockToCTileMap block_to_ctile_map_;
+        index_t block_start_;
+        index_t mblock_id_off_;
+    };
+
+    template <index_t MPerBlock_, index_t NPerBlock_>
+    struct BlockToCTileMap_M00_N0_M01Adapt_MLoops
+    {
+        static constexpr auto I0 = Number<0>{};
+        static constexpr auto I1 = Number<1>{};
+
+        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops() = default;
+
+        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
+            const BlockToCTileMap_M00_N0_M01Adapt_MLoops&) = default;
+        __host__ __device__
+        BlockToCTileMap_M00_N0_M01Adapt_MLoops(BlockToCTileMap_M00_N0_M01Adapt_MLoops&&) = default;
+        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops&
+        operator=(const BlockToCTileMap_M00_N0_M01Adapt_MLoops&) = default;
+        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops&
+        operator=(BlockToCTileMap_M00_N0_M01Adapt_MLoops&&) = default;
+
+        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(index_t M,
+                                                                   index_t N,
+                                                                   index_t M01 = 8)
+            : M_(M), N_(N), M01_(M01)
+        {
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
+            const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
+            : BlockToCTileMap_M00_N0_M01Adapt_MLoops(
+                  c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
+        {
+        }
+
+        __host__ __device__ constexpr index_t CalculateMLoops() const
+        {
+            return math::integer_divide_ceil(M_, MPerBlock_);
+        }
+
+        __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
+        {
+            const auto M0 = math::integer_divide_ceil(M, MPerBlock);
+            const auto N0 = math::integer_divide_ceil(N, NPerBlock);
+
+            return M0 * N0;
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
+        {
+            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+        {
+            return true;
+        }
+
+        template <typename TopIdx>
+        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+        {
+            auto block_1d_id = idx_top[I0];
+
+            const auto M0 = 1; // math::integer_divide_ceil(M_, MPerBlock_);
+            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
+
+            block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
+
+            index_t idx_N0 = block_1d_id % N0;
+            index_t idx_M0 = block_1d_id / N0;
+
+            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
+
+            index_t idx_M00          = idx_M0 / M01_;
+            index_t idx_M01          = idx_M0 % M01_;
+            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
+
+            return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
+                              idx_N0_M01_local / M01_adapt);
+        }
+
+        template <typename CTileIdx, typename CTileDim>
+        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
+                                                 const CTileDim& /* c_tile_dim */) const
+        {
+            return true; // always valid provided that user gets grid size from CalculateGridSize()
+        }
+
+        private:
+        index_t M_;
+        index_t N_;
+        index_t M01_;
+    };
+
+    using Block2ETileMap = BlockToCTileMap_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
+    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
 
     struct GemmBiasTransKernelArg
     {
         // pointers
-        const ADataType* a_ptr_;
-        const BDataType* b_ptr_;
-        typename GridwiseGemm::DsGridPointer ds_ptr_;
-        EDataType* e_ptr_;
+        const void* a_ptr_;
+        const void* b_ptr_;
+        std::array<const void*, NumDTensor> ds_ptr_;
+        void* e_ptr_;
 
-        index_t M, N, K;
+        index_t M_, N_, K_;
+        index_t StrideA_, StrideB_;
+        std::array<index_t, NumDTensor> StrideDs_;
+        index_t StrideE_;
 
         // tensor descriptors for problem definiton
         AGridDesc_M_K a_grid_desc_m_k_;
@@ -415,12 +591,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             const index_t StrideC = gemm_descs[i].stride_C_;
 
             // pointer
-            typename GridwiseGemm::DsGridPointer p_ds_grid{};
+            std::array<const void*, NumDTensor> p_ds_grid;
 
             static_for<0, NumDTensor, 1>{}([&](auto j) {
                 using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
-                p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
+                p_ds_grid[j] = static_cast<const DDataType*>(p_Ds[i][j]);
             });
 
             // tensor descriptors for problem definiton
@@ -436,9 +612,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                     M, N, gemm_descs[i].stride_Ds_[j]);
             });
 
-            const auto e_grid_desc_m_n =
-                DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
-
             // tensor descriptors for block/thread-wise copy
             const auto a_grid_desc_ak0_m_ak1 =
                 GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
@@ -446,6 +619,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             const auto b_grid_desc_bk0_n_bk1 =
                 GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
 
+            const auto e_grid_desc_m_n =
+                DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
+
             // block-to-e-tile map
             const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n};
@@ -479,13 +655,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                     e_grid_desc_m_n);
 
                 gemm_desc_kernel_arg_.push_back(
-                    GemmBiasTransKernelArg{static_cast<const ADataType*>(p_As[i]),
-                                           static_cast<const BDataType*>(p_Bs[i]),
+                    GemmBiasTransKernelArg{p_As[i],
+                                           p_Bs[i],
                                            p_ds_grid,
-                                           static_cast<EDataType*>(p_Es[i]),
+                                           p_Es[i],
                                            M,
                                            N,
                                            K,
+                                           StrideA,
+                                           StrideB,
+                                           {},
+                                           StrideC,
                                            a_grid_desc_m_k,
                                            b_grid_desc_n_k,
                                            ds_grid_desc_m_n,
@@ -526,6 +706,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         {
             bool has_main_k_block_loop = true;
 
+            std::vector<GroupedGemmKernelArgument<NumDTensor>> grouped_gemm_kernel_args;
+            grouped_gemm_kernel_args.reserve(arg.gemm_desc_kernel_arg_.size());
+
             for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
             {
 #if DEBUG_LOG
@@ -568,12 +752,25 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                 {
                     throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
                 }
+
+                grouped_gemm_kernel_args.push_back(
+                    GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_,
+                                                          arg.gemm_desc_kernel_arg_[i].b_ptr_,
+                                                          {},
+                                                          arg.gemm_desc_kernel_arg_[i].e_ptr_,
+                                                          arg.gemm_desc_kernel_arg_[i].M_,
+                                                          arg.gemm_desc_kernel_arg_[i].N_,
+                                                          arg.gemm_desc_kernel_arg_[i].K_,
+                                                          arg.gemm_desc_kernel_arg_[i].StrideA_,
+                                                          arg.gemm_desc_kernel_arg_[i].StrideB_,
+                                                          arg.gemm_desc_kernel_arg_[i].StrideDs_,
+                                                          arg.gemm_desc_kernel_arg_[i].StrideE_});
             }
 
             hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
-                                                  arg.gemm_desc_kernel_arg_.data(),
-                                                  arg.gemm_desc_kernel_arg_.size() *
-                                                      sizeof(GemmBiasTransKernelArg),
+                                                  grouped_gemm_kernel_args.data(),
+                                                  grouped_gemm_kernel_args.size() *
+                                                      sizeof(GroupedGemmKernelArgument<NumDTensor>),
                                                   hipMemcpyHostToDevice,
                                                   stream_config.stream_id_));
@@ -581,9 +778,14 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             auto launch_kernel = [&](auto has_main_k_block_loop_) {
                 const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
-                                                            GemmBiasTransKernelArg,
+                                                            GroupedGemmKernelArgument<NumDTensor>,
                                                             GemmSpec,
+                                                            ALayout,
+                                                            BLayout,
+                                                            DsLayout,
+                                                            ELayout,
                                                             Block2ETileMap,
+                                                            GroupedGemmBlock2ETileMap,
                                                             AElementwiseOperation,
                                                             BElementwiseOperation,
                                                             CDEElementwiseOperation,
...
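For reference, the following is a minimal standalone sketch, not part of this commit, of the index math the two new tile maps implement: BlockToCTileMap_M00_N0_M01Adapt_MLoops fixes M0 to 1 in CalculateBottomIndex, so each group's grid only covers a single row of N tiles, and OffsettedBlockToCTileMapMLoops adds mblock_id_off_ so the kernel's m_id loop walks the remaining M tiles. The M, N, MPerBlock and NPerBlock values below are illustrative assumptions.

// Illustrative sketch only (not part of the commit).
#include <cstdio>

int main()
{
    const int M = 512, N = 384;                 // one group's GEMM shape (example)
    const int MPerBlock = 256, NPerBlock = 128; // tile sizes (example)
    const int block_start = 0;                  // this group's first block id

    const int N0      = (N + NPerBlock - 1) / NPerBlock; // tiles along N
    const int m_loops = (M + MPerBlock - 1) / MPerBlock; // CalculateMLoops()

    // Grid size per group is M0 * N0 with M0 == 1, i.e. just N0 blocks.
    for(int block_1d_id = block_start; block_1d_id < block_start + N0; ++block_1d_id)
    {
        for(int m_id = 0; m_id < m_loops; ++m_id)
        {
            // Underlying map: with M0 == 1 the bottom index is (0, idx_N0).
            const int idx_N0 = (block_1d_id - block_start) % N0;
            // Offsetted map: shift the M tile by mblock_id_off_ (= m_id).
            const int m_tile = 0 + m_id;
            const int n_tile = idx_N0;
            std::printf("block %d, m_id %d -> e-tile (%d, %d)\n",
                        block_1d_id, m_id, m_tile, n_tile);
        }
    }
    return 0;
}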
@@ -425,6 +425,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
             Number<NumDTensor>{});
     }
 
+    __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
+
     template <bool HasMainKBlockLoop,
               typename AGridDesc_AK0_M_AK1,
               typename BGridDesc_BK0_N_BK1,
@@ -868,10 +870,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
               typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
 #endif
               typename Block2ETileMap>
-    __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
-                               const ABDataType* __restrict__ p_b_grid,
+    __device__ static void Run(const void* __restrict__ p_a_grid_,
+                               const void* __restrict__ p_b_grid_,
                                DsGridPointer p_ds_grid,
-                               EDataType* __restrict__ p_e_grid,
+                               void* __restrict__ p_e_grid_,
                                void* __restrict__ p_shared,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
@@ -881,7 +883,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                                const index_t K,
                                const index_t StrideA,
                                const index_t StrideB,
-                               const index_t StrideDs[],
+                               const std::array<index_t, NumDTensor> StrideDs,
                                const index_t StrideE,
 #if 0
                                const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
@@ -893,6 +895,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
 #endif
                                const Block2ETileMap& block_2_etile_map)
     {
+        const auto p_a_grid = reinterpret_cast<const ABDataType*>(p_a_grid_);
+        const auto p_b_grid = reinterpret_cast<const ABDataType*>(p_b_grid_);
+        const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
+
         // tensor descriptors for problem definiton
         const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
...
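Also for reference, a small sketch, again not part of the commit, of the type-erasure pattern introduced here: the host packs plain void* grid pointers into GroupedGemmKernelArgument, and GridwiseGemm::Run casts them back to the concrete element types at its entry. The run() helper and the float element types below are stand-ins chosen only for illustration.

// Illustrative sketch only (not part of the commit).
#include <cstdio>
#include <vector>

using ABDataType = float; // stand-in element types
using EDataType  = float;

// Mimics the casts at the top of GridwiseGemmMultipleD_xdl_cshuffle::Run.
void run(const void* p_a_grid_, const void* p_b_grid_, void* p_e_grid_, int n)
{
    const auto* p_a_grid = reinterpret_cast<const ABDataType*>(p_a_grid_);
    const auto* p_b_grid = reinterpret_cast<const ABDataType*>(p_b_grid_);
    auto* p_e_grid       = reinterpret_cast<EDataType*>(p_e_grid_);

    for(int i = 0; i < n; ++i) // trivial elementwise stand-in for the GEMM body
        p_e_grid[i] = p_a_grid[i] + p_b_grid[i];
}

int main()
{
    std::vector<ABDataType> a(4, 1.0f), b(4, 2.0f);
    std::vector<EDataType> e(4, 0.0f);

    run(a.data(), b.data(), e.data(), 4); // pointers pass through as void*
    std::printf("e[0] = %.1f\n", e[0]);   // prints 3.0
    return 0;
}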