Commit fee2002c authored by Jing Zhang

add b2c_tile_map

parent 326d6bc6
@@ -20,6 +20,24 @@ struct GemmDesc
std::vector<ck::index_t> stride_Ds_;
};
template <index_t NumDTensor = 0>
struct GroupedGemmKernelArgument
{
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
};
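For context, a minimal host-side sketch of how this POD argument block can be filled and staged for the kernel; the workspace pointer, stream, and per-group values below are illustrative, not taken from this commit:

// Hypothetical usage: pack one argument struct per group, then copy the
// array into the device workspace before launching the grouped kernel.
std::vector<GroupedGemmKernelArgument<0>> args;
args.push_back(GroupedGemmKernelArgument<0>{
    p_a, p_b, {}, p_e, M, N, K, StrideA, StrideB, {}, StrideE});
hipMemcpyWithStream(p_workspace, args.data(),
                    args.size() * sizeof(GroupedGemmKernelArgument<0>),
                    hipMemcpyHostToDevice, stream);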
template <typename ALayout,
typename BLayout,
typename DsLayout,
......
@@ -25,7 +25,12 @@ namespace device {
template <typename GridwiseGemm,
typename GemmDesc,
GemmSpecialization GemmSpec,
typename Block2ETileMapKSplit,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename Block2ETileMap,
typename GroupedGemmBlock2ETileMap,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
@@ -98,11 +103,12 @@ __global__ void
if(M == 0 || N == 0 || K == 0)
return;
const index_t StrideA = K;
const index_t StrideB = K;
const index_t StrideDs[] = {};
const index_t StrideE = N;
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
const auto StrideDs = gemm_desc_ptr[group_id].StrideDs;
const auto StrideE = gemm_desc_ptr[group_id].StrideE;
#if 0
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -110,34 +116,60 @@ __global__ void
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
#endif
using DsDataType = ck::Tuple<>;
const auto e_grid_desc_m_n =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
const index_t BlockStart = group_id * grid_size_grp;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n};
const auto local_b2e_tile_map = Block2ETileMapKSplit{e_grid_desc_m_n};
const auto block_2_etile_map = GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart);
constexpr auto NumDTensor = 0;
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
DsGridPointer p_ds_grid_;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
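// Each workgroup makes m_loops passes over its group's E tiles: pass m_id
// handles M-tile row m_id across all N tiles (the local map pins M0 to 1).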
auto m_loops = local_b2e_tile_map.CalculateMLoops();
index_t m_id = 0;
do
{
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, m_id);
GridwiseGemm::
template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
p_ds_grid_,
gemm_desc_ptr[group_id].p_e_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
block_2_etile_map);
m_id += 1;
} while(m_id < m_loops);
GridwiseGemm::template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
gemm_desc_ptr[group_id].a_ptr_,
gemm_desc_ptr[group_id].b_ptr_,
gemm_desc_ptr[group_id].ds_ptr_,
gemm_desc_ptr[group_id].e_ptr_,
p_shared,
a_element_op,
b_element_op,
c_element_op,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
block_2_etile_map);
#endif
#else
@@ -342,18 +374,162 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using Block2ETileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMapMLoops
{
using underlying_type = UnderlyingBlockToCTileMap;
__host__ __device__
OffsettedBlockToCTileMapMLoops(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t block_start,
index_t mblock_id_off = 0)
{
block_to_ctile_map_ = block_to_ctile_map;
block_start_ = block_start;
mblock_id_off_ = mblock_id_off;
}
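// Translate a global block id into a (MBlock, NBlock) tile index: shift the
// id into this group's local range, then bump MBlock by the pass's row offset.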
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] - block_start_));
return make_tuple(idx_bot[Number<0>{}] + mblock_id_off_, idx_bot[Number<1>{}]);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
template <typename CGridDesc_M_N>
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
index_t mblock_id_off_;
};
template <index_t MPerBlock_, index_t NPerBlock_>
struct BlockToCTileMap_M00_N0_M01Adapt_MLoops
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
const BlockToCTileMap_M00_N0_M01Adapt_MLoops&) = default;
__host__ __device__
BlockToCTileMap_M00_N0_M01Adapt_MLoops(BlockToCTileMap_M00_N0_M01Adapt_MLoops&&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops&
operator=(const BlockToCTileMap_M00_N0_M01Adapt_MLoops&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops&
operator=(BlockToCTileMap_M00_N0_M01Adapt_MLoops&&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(index_t M,
index_t N,
index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{
}
template <typename CGridDesc_M_N>
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
: BlockToCTileMap_M00_N0_M01Adapt_MLoops(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{
}
__host__ __device__ constexpr index_t CalculateMLoops() const
{
return math::integer_divide_ceil(M_, MPerBlock_);
}
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock_);
const auto N0 = math::integer_divide_ceil(N, NPerBlock_);
return M0 * N0;
}
template <typename CGridDesc_M_N>
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
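// M0 is treated as 1 below, so one grid pass covers a single M-tile row;
// OffsettedBlockToCTileMapMLoops advances rows through mblock_id_off_.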
auto block_1d_id = idx_top[I0];
const auto M0 = 1; // math::integer_divide_ceil(M_, MPerBlock_);
const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
index_t idx_N0 = block_1d_id % N0;
index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
private:
index_t M_;
index_t N_;
index_t M01_;
};
using Block2ETileMap = BlockToCTileMap_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
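For intuition, a small worked example with made-up sizes (MPerBlock = NPerBlock = 128; none of these numbers come from the commit):

// One group with M = 512, N = 384, so N0 = 3 N tiles per row.
BlockToCTileMap_M00_N0_M01Adapt_MLoops<128, 128> local_map(512, 384);
// M0 is pinned to 1 inside CalculateBottomIndex: a single grid pass covers
// one M-tile row, and CalculateMLoops() == 4 passes walk all four rows.
const index_t m_loops = local_map.CalculateMLoops(); // 4
// Offset the map for a group starting at flat block 128, M-tile row 2:
OffsettedBlockToCTileMapMLoops<decltype(local_map)> tile_map(local_map, 128, 2);
// Flat block 130 -> local block 2 -> (MBlock 0 + 2, NBlock 2): tile (2, 2).
const auto mn = tile_map.CalculateBottomIndex(make_multi_index(130));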
struct GemmBiasTransKernelArg
{
// pointers
const ADataType* a_ptr_;
const BDataType* b_ptr_;
typename GridwiseGemm::DsGridPointer ds_ptr_;
EDataType* e_ptr_;
const void* a_ptr_;
const void* b_ptr_;
std::array<const void*, NumDTensor> ds_ptr_;
void* e_ptr_;
index_t M, N, K;
index_t M_, N_, K_;
index_t StrideA_, StrideB_;
std::array<index_t, NumDTensor> StrideDs_;
index_t StrideE_;
// tensor descriptors for problem definition
AGridDesc_M_K a_grid_desc_m_k_;
@@ -415,12 +591,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
const index_t StrideC = gemm_descs[i].stride_C_;
// pointer
typename GridwiseGemm::DsGridPointer p_ds_grid{};
std::array<const void*, NumDTensor> p_ds_grid;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
p_ds_grid[j] = static_cast<const DDataType*>(p_Ds[i][j]);
});
// tensor descriptors for problem definition
@@ -436,9 +612,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
M, N, gemm_descs[i].stride_Ds_[j]);
});
const auto e_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
// tensor descriptors for block/thread-wise copy
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
@@ -446,6 +619,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
const auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
const auto e_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
// block-to-e-tile map
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n};
@@ -479,13 +655,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
e_grid_desc_m_n);
gemm_desc_kernel_arg_.push_back(
GemmBiasTransKernelArg{static_cast<const ADataType*>(p_As[i]),
static_cast<const BDataType*>(p_Bs[i]),
GemmBiasTransKernelArg{p_As[i],
p_Bs[i],
p_ds_grid,
static_cast<EDataType*>(p_Es[i]),
p_Es[i],
M,
N,
K,
StrideA,
StrideB,
{},
StrideC,
a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
@@ -526,6 +706,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
{
bool has_main_k_block_loop = true;
std::vector<GroupedGemmKernelArgument<NumDTensor>> grouped_gemm_kernel_args;
grouped_gemm_kernel_args.reserve(arg.gemm_desc_kernel_arg_.size());
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
#if DEBUG_LOG
@@ -568,12 +752,25 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
grouped_gemm_kernel_args.push_back(
GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_,
arg.gemm_desc_kernel_arg_[i].b_ptr_,
{},
arg.gemm_desc_kernel_arg_[i].e_ptr_,
arg.gemm_desc_kernel_arg_[i].M_,
arg.gemm_desc_kernel_arg_[i].N_,
arg.gemm_desc_kernel_arg_[i].K_,
arg.gemm_desc_kernel_arg_[i].StrideA_,
arg.gemm_desc_kernel_arg_[i].StrideB_,
arg.gemm_desc_kernel_arg_[i].StrideDs_,
arg.gemm_desc_kernel_arg_[i].StrideE_});
}
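// Only the slim POD arguments travel to the device; the kernel rebuilds the
// grid descriptors on the fly from M/N/K and the strides.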
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() *
sizeof(GemmBiasTransKernelArg),
grouped_gemm_kernel_args.data(),
grouped_gemm_kernel_args.size() *
sizeof(GroupedGemmKernelArgument<NumDTensor>),
hipMemcpyHostToDevice,
stream_config.stream_id_));
@@ -581,9 +778,14 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
GemmBiasTransKernelArg,
GroupedGemmKernelArgument<NumDTensor>,
GemmSpec,
ALayout,
BLayout,
DsLayout,
ELayout,
Block2ETileMap,
GroupedGemmBlock2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
......
@@ -425,6 +425,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Number<NumDTensor>{});
}
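// Expose the M tile height so callers (e.g. a block-to-tile map deriving
// per-group M-loop counts) can query it; the call site is outside this diff.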
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
@@ -868,10 +870,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
#endif
typename Block2ETileMap>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
__device__ static void Run(const void* __restrict__ p_a_grid_,
const void* __restrict__ p_b_grid_,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_e_grid_,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
@@ -881,7 +883,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const index_t K,
const index_t StrideA,
const index_t StrideB,
const index_t StrideDs[],
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
#if 0
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
@@ -893,6 +895,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
#endif
const Block2ETileMap& block_2_etile_map)
{
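// Grid pointers arrive type-erased as void*; recover typed pointers so one
// kernel-side Run signature serves every data-type instantiation.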
const auto p_a_grid = reinterpret_cast<const ABDataType*>(p_a_grid_);
const auto p_b_grid = reinterpret_cast<const ABDataType*>(p_b_grid_);
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
// tensor descriptors for problem definition
const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
......