Commit f4f94f70 authored by Jing Zhang's avatar Jing Zhang
Browse files

merge group and non-group

parent bb9c4a89
......@@ -10,7 +10,7 @@
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_grouped_gemm_xdlops_v2r3.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "gemm_specialization.hpp"
namespace ck {
......@@ -182,7 +182,7 @@ struct DeviceGroupedGemmXdl
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
......
......@@ -54,6 +54,82 @@ __global__ void
block_2_ctile_map);
}
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
#if 1
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
i < group_count)
{
auto group_id = i;
const index_t block_id_grp = block_id - gemm_desc_[group_id].BlockStart;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
gemm_desc_[group_id].a_ptr,
gemm_desc_[group_id].b_ptr,
gemm_desc_[group_id].c_ptr,
p_shared,
gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_[group_id].block_2_ctile_map_,
block_id_grp);
}
});
#else
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
i < group_count)
? i
: group_id;
});
const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
gemm_desc_ptr[group_id].a_ptr,
gemm_desc_ptr[group_id].b_ptr,
gemm_desc_ptr[group_id].c_ptr,
p_shared,
gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_,
block_id_grp);
#endif
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
......@@ -350,7 +426,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
const Block2CTileMap& block_2_ctile_map,
ck::index_t block_id = get_block_1d_id())
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
......@@ -363,7 +440,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_id));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
......
#ifndef CK_GRIDWISE_GROUPED_GEMM_XDLOPS_V2R3_HPP
#define CK_GRIDWISE_GROUPED_GEMM_XDLOPS_V2R3_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
#if 1
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
i < group_count)
{
auto group_id = i;
const index_t block_id_grp = block_id - gemm_desc_[group_id].BlockStart;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
gemm_desc_[group_id].a_ptr,
gemm_desc_[group_id].b_ptr,
gemm_desc_[group_id].c_ptr,
p_shared,
gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_[group_id].block_2_ctile_map_,
block_id_grp);
}
});
#else
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
i < group_count)
? i
: group_id;
});
const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
gemm_desc_ptr[group_id].a_ptr,
gemm_desc_ptr[group_id].b_ptr,
gemm_desc_ptr[group_id].c_ptr,
p_shared,
gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_,
block_id_grp);
#endif
}
#if 0
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r4(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const StaticallyIndexedArray<AGridDesc_K0_M_K1, MaxGroupCount> a_grid_desc_k0_m_k1,
const StaticallyIndexedArray<BGridDesc_K0_N_K1, MaxGroupCount> b_grid_desc_k0_n_k1,
const StaticallyIndexedArray<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, MaxGroupCount>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_shapes,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const StaticallyIndexedArray<Block2CTileMap, MaxGroupCount> block_2_ctile_map)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
__shared__ AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_[MaxGroupCount];
__shared__ BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_[MaxGroupCount];
__shared__ CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[MaxGroupCount];
__shared__ Block2CTileMap block_2_ctile_map_[MaxGroupCount];
__shared__ GemmDesc gemm_shapes_[MaxGroupCount];
if(get_thread_local_1d_id())
{
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
a_grid_desc_k0_m_k1_[i] = a_grid_desc_k0_m_k1[i];
b_grid_desc_k0_n_k1_[i] = b_grid_desc_k0_n_k1[i];
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[i] = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[i];
block_2_ctile_map_[i] = block_2_ctile_map[i];
gemm_shapes_[i] = gemm_shapes[i];
});
}
block_sync_lds();
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_shapes[i].BlockStart &&
block_id < (gemm_shapes[i].BlockStart + gemm_shapes[i].BlockSize))
? i
: group_id;
});
const index_t block_id_grp = block_id - gemm_shapes_[group_id].BlockStart;
const index_t a_offset_grp = gemm_shapes_[group_id].OffsetA;
const index_t b_offset_grp = gemm_shapes_[group_id].OffsetB;
const index_t c_offset_grp = gemm_shapes_[group_id].OffsetC;
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
a_grid_desc_k0_m_k1_[group_id],
b_grid_desc_k0_n_k1_[group_id],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[group_id],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map_[group_id],
block_id_grp);
}
#endif
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BBlockLdsExtraN,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
index_t NumPrefetch = 1>
struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
return a_block_desc_k0_m_k1;
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
{
constexpr auto max_lds_align = K1;
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
return b_block_desc_k0_n_k1;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check NumPrefetch
if constexpr(NumPrefetch == 1)
{
// 1-stage prefetch always supported
}
else if constexpr(NumPrefetch == 2)
{
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
if(!((K0 / K0PerBlock) % 2 == 0))
{
return false;
}
}
else
{
return false;
}
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
return false;
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
// TODO move this function into GEMM-pipeline class
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
return has_main_k0_block_loop;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
K1>;
return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
}
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map,
const index_t block_id)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_id));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumPrefetch>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumPrefetch>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
K1>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
remove_cvref_t<decltype(a_blockwise_copy)>,
remove_cvref_t<decltype(a_grid_buf)>,
remove_cvref_t<decltype(a_block_buf)>,
remove_cvref_t<decltype(a_block_slice_copy_step)>,
remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
remove_cvref_t<decltype(b_blockwise_copy)>,
remove_cvref_t<decltype(b_grid_buf)>,
remove_cvref_t<decltype(b_block_buf)>,
remove_cvref_t<decltype(b_block_slice_copy_step)>,
remove_cvref_t<decltype(blockwise_gemm)>,
remove_cvref_t<decltype(c_thread_buf)>,
NumPrefetch,
HasMainK0BlockLoop>{};
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
a_block_desc_k0_m_k1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_k0_n_k1,
b_block_desc_k0_n_k1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
K0BlockMainLoop);
// output: register to global memory
{
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_grid_idx =
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_grid));
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_grid_idx =
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid));
auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
CElementwiseOperation,
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2]),
c_element_op};
c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_grid_buf);
}
}
};
} // namespace ck
#endif
......@@ -69,10 +69,12 @@ void profile_grouped_gemm_impl(int do_verification,
}
};
if(!(Ms.size() == Ns.size() && Ns.size() == Ks.size() && Ks.size() == StrideAs.size() &&
StrideAs.size() == StrideBs.size() && StrideBs.size() == StrideCs.size()))
int group_count = Ms.size();
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
group_count == StrideBs.size() && group_count == StrideCs.size()))
{
throw std::runtime_error("wrong! inconsistent Ms, Ns, Ks, StrideA/B/Cs size\n");
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
}
std::vector<Tensor<ADataType>> a_m_k;
......@@ -125,9 +127,22 @@ void profile_grouped_gemm_impl(int do_verification,
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
std::vector<GemmShape> gemm_shapes;
a_device_buf.reserve(group_count);
b_device_buf.reserve(group_count);
c_device_buf.reserve(group_count);
for(int i = 0; i < Ms.size(); i++)
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
p_a.reserve(group_count);
p_b.reserve(group_count);
p_c.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
gemm_shapes.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a_device_buf.push_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSize()));
......@@ -141,15 +156,11 @@ void profile_grouped_gemm_impl(int do_verification,
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data());
gemm_shapes.push_back({Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
StrideCs[i],
a_device_buf[i]->GetDeviceBuffer(),
b_device_buf[i]->GetDeviceBuffer(),
c_device_buf[i]->GetDeviceBuffer()});
gemm_shapes.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i]});
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
}
// add device GEMM instances
......@@ -204,7 +215,10 @@ void profile_grouped_gemm_impl(int do_verification,
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(gemm_shapes,
gemm_ptr->MakeArgumentPointer(p_a,
p_b,
p_c,
gemm_shapes,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment