"vscode:/vscode.git/clone" did not exist on "7967161abf66de3cf42eb44672d61e6d4a8ddf56"
Commit e3a4b967 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed mem issue with unique_ptr

parent 8fb2b172
#ifndef CK_GRIDWISE_GROUPED_GEMM_XDLOPS_V2R3_HPP
#define CK_GRIDWISE_GROUPED_GEMM_XDLOPS_V2R3_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd)
{
const index_t group_id = i;
const index_t block_id_grp = block_id - gemm_desc_[i].BlockStart;
const index_t a_offset_grp = gemm_desc_[i].OffsetA;
const index_t b_offset_grp = gemm_desc_[i].OffsetB;
const index_t c_offset_grp = gemm_desc_[i].OffsetC;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
gemm_desc_[i].a_grid_desc_k0_m_k1_,
gemm_desc_[i].b_grid_desc_k0_n_k1_,
gemm_desc_[i].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_[i].block_2_ctile_map_,
block_id_grp);
}
});
}
#if 0
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r4(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const StaticallyIndexedArray<AGridDesc_K0_M_K1, MaxGroupCount> a_grid_desc_k0_m_k1,
const StaticallyIndexedArray<BGridDesc_K0_N_K1, MaxGroupCount> b_grid_desc_k0_n_k1,
const StaticallyIndexedArray<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, MaxGroupCount>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_shapes,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const StaticallyIndexedArray<Block2CTileMap, MaxGroupCount> block_2_ctile_map)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
__shared__ AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_[MaxGroupCount];
__shared__ BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_[MaxGroupCount];
__shared__ CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[MaxGroupCount];
__shared__ Block2CTileMap block_2_ctile_map_[MaxGroupCount];
__shared__ GemmDesc gemm_shapes_[MaxGroupCount];
if(get_thread_local_1d_id())
{
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
a_grid_desc_k0_m_k1_[i] = a_grid_desc_k0_m_k1[i];
b_grid_desc_k0_n_k1_[i] = b_grid_desc_k0_n_k1[i];
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[i] = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[i];
block_2_ctile_map_[i] = block_2_ctile_map[i];
gemm_shapes_[i] = gemm_shapes[i];
});
}
block_sync_lds();
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_shapes[i].BlockStart &&
block_id < (gemm_shapes[i].BlockStart + gemm_shapes[i].BlockSize))
? i
: group_id;
});
const index_t block_id_grp = block_id - gemm_shapes_[group_id].BlockStart;
const index_t a_offset_grp = gemm_shapes_[group_id].OffsetA;
const index_t b_offset_grp = gemm_shapes_[group_id].OffsetB;
const index_t c_offset_grp = gemm_shapes_[group_id].OffsetC;
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
a_grid_desc_k0_m_k1_[group_id],
b_grid_desc_k0_n_k1_[group_id],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[group_id],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map_[group_id],
block_id_grp);
}
#endif
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BBlockLdsExtraN,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
index_t NumPrefetch = 1>
struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
return a_block_desc_k0_m_k1;
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
{
constexpr auto max_lds_align = K1;
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
return b_block_desc_k0_n_k1;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check NumPrefetch
if constexpr(NumPrefetch == 1)
{
// 1-stage prefetch always supported
}
else if constexpr(NumPrefetch == 2)
{
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
if(!((K0 / K0PerBlock) % 2 == 0))
{
return false;
}
}
else
{
return false;
}
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
return false;
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
// TODO move this function into GEMM-pipeline class
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
return has_main_k0_block_loop;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
K1>;
return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
}
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map,
const index_t block_id)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_id));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumPrefetch>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumPrefetch>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
K1>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
remove_cvref_t<decltype(a_blockwise_copy)>,
remove_cvref_t<decltype(a_grid_buf)>,
remove_cvref_t<decltype(a_block_buf)>,
remove_cvref_t<decltype(a_block_slice_copy_step)>,
remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
remove_cvref_t<decltype(b_blockwise_copy)>,
remove_cvref_t<decltype(b_grid_buf)>,
remove_cvref_t<decltype(b_block_buf)>,
remove_cvref_t<decltype(b_block_slice_copy_step)>,
remove_cvref_t<decltype(blockwise_gemm)>,
remove_cvref_t<decltype(c_thread_buf)>,
NumPrefetch,
HasMainK0BlockLoop>{};
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
a_block_desc_k0_m_k1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_k0_n_k1,
b_block_desc_k0_n_k1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
K0BlockMainLoop);
// output: register to global memory
{
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_grid_idx =
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_grid));
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_grid_idx =
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid));
auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
CElementwiseOperation,
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2]),
c_element_op};
c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_grid_buf);
}
}
};
} // namespace ck
#endif
......@@ -81,19 +81,17 @@ int main(int argc, char* argv[])
// GEMM shape
std::vector<ck::GemmShape> gemm_shapes;
int A_size = 0, B_size = 0, C_size = 0;
for(int i = 0; i < group_count; i++)
{
int M = 256 + 256 * i;
int N = 128 + 128 * i;
int K = 64 + 64 * i;
// int M = 256 + 256 * i;
// int N = 128 + 128 * i;
// int K = 64 + 64 * i;
gemm_shapes.push_back({M, N, K, K, K, N, nullptr, nullptr, nullptr});
int M = 3840;
int N = 1024;
int K = 4096;
A_size += gemm_shapes[i].M * gemm_shapes[i].K;
B_size += gemm_shapes[i].N * gemm_shapes[i].K;
C_size += gemm_shapes[i].M * gemm_shapes[i].N;
gemm_shapes.push_back({M, N, K, K, N, N, nullptr, nullptr, nullptr});
}
auto f_host_tensor_descriptor =
......@@ -115,6 +113,10 @@ int main(int argc, char* argv[])
std::vector<Tensor<CDataType>> c_host_tensors;
std::vector<Tensor<CDataType>> c_device_tensors;
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < gemm_shapes.size(); i++)
......@@ -133,13 +135,10 @@ int main(int argc, char* argv[])
<< std::endl;
flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N;
num_btype += sizeof(ADataType) * gemm_shapes[i].M * gemm_shapes[i].K +
sizeof(BDataType) * gemm_shapes[i].K * gemm_shapes[i].N +
sizeof(CDataType) * gemm_shapes[i].M * gemm_shapes[i].N;
}
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();
for(int i = 0; i < gemm_shapes.size(); i++)
{
switch(init_method)
{
case 0: break;
......@@ -157,38 +156,23 @@ int main(int argc, char* argv[])
}
}
DeviceMem a_tensors_device_buf(sizeof(ADataType) * A_size);
DeviceMem b_tensors_device_buf(sizeof(BDataType) * B_size);
DeviceMem c_tensors_device_buf(sizeof(CDataType) * C_size);
std::vector<ADataType> a_tensors_data, b_tensors_data, c_tensors_data;
A_size = 0;
B_size = 0;
C_size = 0;
for(int i = 0; i < gemm_shapes.size(); i++)
{
a_tensors_data.insert(
a_tensors_data.end(), a_tensors[i].mData.begin(), a_tensors[i].mData.end());
b_tensors_data.insert(
b_tensors_data.end(), b_tensors[i].mData.begin(), b_tensors[i].mData.end());
gemm_shapes[i].p_a =
static_cast<ADataType*>(a_tensors_device_buf.GetDeviceBuffer()) + A_size;
gemm_shapes[i].p_b =
static_cast<BDataType*>(b_tensors_device_buf.GetDeviceBuffer()) + B_size;
gemm_shapes[i].p_c =
static_cast<CDataType*>(c_tensors_device_buf.GetDeviceBuffer()) + C_size;
A_size += gemm_shapes[i].M * gemm_shapes[i].K;
B_size += gemm_shapes[i].N * gemm_shapes[i].K;
C_size += gemm_shapes[i].M * gemm_shapes[i].N;
a_tensors_device.push_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize()));
b_tensors_device.push_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * a_tensors[i].mDesc.GetElementSize()));
c_tensors_device.push_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize()));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
gemm_shapes[i].p_a = a_tensors_device[i]->GetDeviceBuffer();
gemm_shapes[i].p_b = b_tensors_device[i]->GetDeviceBuffer();
gemm_shapes[i].p_c = c_tensors_device[i]->GetDeviceBuffer();
}
a_tensors_device_buf.ToDevice(a_tensors_data.data());
b_tensors_device_buf.ToDevice(b_tensors_data.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
......@@ -214,24 +198,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_tensors_data.resize(C_size);
c_tensors_device_buf.FromDevice(c_tensors_data.data());
C_size = 0;
for(int i = 0; i < gemm_shapes.size(); i++)
{
memcpy(c_device_tensors[i].mData.data(),
c_tensors_data.data() + C_size,
c_device_tensors[i].mData.size() * sizeof(CDataType));
C_size += gemm_shapes[i].M * gemm_shapes[i].N;
}
if(do_verification)
{
for(int i = 0; i < gemm_shapes.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
......
......@@ -70,7 +70,7 @@ template <typename AElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape> gemm_shapes,
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......
......@@ -242,7 +242,7 @@ struct DeviceGroupedGemmXdl
// Argument
struct Argument : public BaseArgument
{
Argument(std::vector<GemmShape> gemm_shapes,
Argument(std::vector<GemmShape>& gemm_shapes,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
......@@ -360,8 +360,7 @@ struct DeviceGroupedGemmXdl
if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop)
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
throw std::runtime_error("wrong! not all gemm has_main_k0_block_loop");
}
}
});
......@@ -435,11 +434,17 @@ struct DeviceGroupedGemmXdl
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(arg.GemmShape_[0].a_grid_desc_k0_m_k1_,
arg.GemmShape_[0].b_grid_desc_k0_n_k1_,
arg.GemmShape_[0].c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
bool isValid = true;
for(int i = 0; i < arg.GemmShape_.size(); i++)
{
isValid &= GridwiseGemm::CheckValidity(arg.GemmShape_[i].a_grid_desc_k0_m_k1_,
arg.GemmShape_[i].b_grid_desc_k0_n_k1_,
arg.GemmShape_[i].c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
}
return isValid;
}
// polymorphic
......@@ -459,7 +464,7 @@ struct DeviceGroupedGemmXdl
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape> gemm_shapes,
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......
......@@ -60,24 +60,22 @@ __global__ void
}
});
#else
const GemmDesc* gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd)
group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
i < group_count)
? i
: group_id;
});
const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
const index_t a_offset_grp = gemm_desc_ptr[group_id].OffsetA;
const index_t b_offset_grp = gemm_desc_ptr[group_id].OffsetB;
const index_t c_offset_grp = gemm_desc_ptr[group_id].OffsetC;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
gemm_desc_ptr[group_id].a_ptr,
gemm_desc_ptr[group_id].b_ptr,
gemm_desc_ptr[group_id].c_ptr,
p_shared,
gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
......
......@@ -12,6 +12,7 @@ struct DeviceMem
{
DeviceMem() = delete;
DeviceMem(std::size_t mem_size);
DeviceMem(const DeviceMem& p);
void* GetDeviceBuffer();
void ToDevice(const void* p);
void FromDevice(void* p);
......
......@@ -5,6 +5,12 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
DeviceMem::DeviceMem(const DeviceMem& p) : mpDeviceBuf(p.mpDeviceBuf), mMemSize(p.mMemSize)
{
// hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
// hipGetErrorString(hipMemcpy(mpDeviceBuf, p.mpDeviceBuf, mMemSize, hipMemcpyDeviceToDevice));
}
void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }
void DeviceMem::ToDevice(const void* p)
......
......@@ -23,9 +23,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
......@@ -48,13 +47,14 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>
DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
// clang-format on
>;
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
add_device_operation_instances(instances,
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace device_grouped_gemm_instance
......
......@@ -16,15 +16,19 @@ namespace tensor_operation {
namespace device {
namespace device_grouped_gemm_instance {
using DeviceGroupedGemmNoOpPtr =
ck::tensor_operation::device::DeviceGroupedGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
//void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
//void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
//void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGroupedGemmNoOpPtr>&);
// void
// add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
// void
// add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
// void
// add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
} // namespace device_grouped_gemm_instance
} // namespace device
......@@ -41,15 +45,15 @@ template <typename ADataType,
typename BLayout,
typename CLayout>
void profile_grouped_gemm_impl(int do_verification,
int init_method,
bool do_log,
int nrepeat,
std::vector<int> Ms,
std::vector<int> Ns,
std::vector<int> Ks,
std::vector<int> StrideAs,
std::vector<int> StrideBs,
std::vector<int> StrideCs)
int init_method,
bool do_log,
int nrepeat,
std::vector<int> Ms,
std::vector<int> Ns,
std::vector<int> Ks,
std::vector<int> StrideAs,
std::vector<int> StrideBs,
std::vector<int> StrideCs)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -65,41 +69,48 @@ void profile_grouped_gemm_impl(int do_verification,
}
};
std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n;
std::vector<Tensor<CDataType>> c_m_n_device_results;
// int A_size = 0, B_size = 0, C_size = 0;
for(int i = 0; i < Ms.size(); i++)
{
a_m_k.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_k_n.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
Ms[i], Ns[i], StrideCs[i], CLayout{})));
a_m_k.push_back(
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_k_n.push_back(
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
std::cout << "a_m_k[" << i << "]:" << a_m_k[i].mDesc << std::endl;
std::cout << "b_k_n[" << i << "]:" << b_k_n[i].mDesc << std::endl;
std::cout << "c_m_n[" << i << "]:" << c_m_n[i].mDesc << std::endl;
std::cout << "c_m_n_device_results[" << i << "]:" << c_m_n_device_results[i].mDesc
<< std::endl;
std::size_t num_thread = std::thread::hardware_concurrency();
switch(init_method)
{
case 0: break;
case 1:
a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
break;
default:
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
case 0: break;
case 1:
a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
break;
default:
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
// set zero to c_device_buf
c_m_n[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
}
c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
// A_size += a_m_k[i].mDesc.GetElementSpace();
// B_size += b_k_n[i].mDesc.GetElementSpace();
// C_size += c_m_n_device_results[i].mDesc.GetElementSpace();
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
......@@ -114,28 +125,112 @@ void profile_grouped_gemm_impl(int do_verification,
// }
std::vector<DeviceMem> a_device_buf, b_device_buf, c_device_buf;
//DeviceMem a_device_buf(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace());
//DeviceMem b_device_buf(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace());
//DeviceMem c_device_buf(sizeof(CDataType) * c_m_n[i].mDesc.GetElementSpace());
// std::vector<DeviceMem> a_device_buf, b_device_buf, c_device_buf;
std::vector<void*> a_device_buf, b_device_buf, c_device_buf;
// DeviceMem a_device_buf_(sizeof(ADataType) * A_size);
// DeviceMem b_device_buf_(sizeof(BDataType) * B_size);
// DeviceMem c_device_buf_(sizeof(CDataType) * C_size);
// std::vector<ADataType> a_tensors_data;
// std::vector<BDataType> b_tensors_data;
// std::vector<CDataType> c_tensors_data;
std::vector<GemmShape> gemm_shapes;
// A_size = 0;
// B_size = 0;
// C_size = 0;
for(int i = 0; i < Ms.size(); i++)
{
a_device_buf.push_back(DeviceMem(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace()));
b_device_buf.push_back(DeviceMem(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace()));
c_device_buf.push_back(DeviceMem(sizeof(CDataType) * c_m_n[i].mDesc.GetElementSpace()));
a_device_buf[i].ToDevice(a_m_k[i].mData.data());
b_device_buf[i].ToDevice(b_k_n[i].mData.data());
c_device_buf[i].ToDevice(c_m_n[i].mData.data());
// a_tensors_data.insert(a_tensors_data.end(), a_m_k[i].mData.begin(),
// a_m_k[i].mData.end()); b_tensors_data.insert(b_tensors_data.end(),
// b_k_n[i].mData.begin(), b_k_n[i].mData.end());
// c_tensors_data.insert(c_tensors_data.end(), c_m_n_device_results[i].mData.begin(),
// c_m_n_device_results[i].mData.end());
void *a_device_buf_, *b_device_buf_, *c_device_buf_;
hipGetErrorString(hipMalloc(static_cast<void**>(&a_device_buf_),
sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace()));
hipGetErrorString(hipMalloc(static_cast<void**>(&b_device_buf_),
sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace()));
hipGetErrorString(
hipMalloc(static_cast<void**>(&c_device_buf_),
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace()));
// DeviceMem a_device_buf_(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace());
// DeviceMem b_device_buf_(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace());
// DeviceMem c_device_buf_(sizeof(CDataType) *
// c_m_n_device_results[i].mDesc.GetElementSpace());
hipGetErrorString(hipMemcpy(a_device_buf_,
a_m_k[i].mData.data(),
sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace(),
hipMemcpyHostToDevice));
hipGetErrorString(hipMemcpy(b_device_buf_,
b_k_n[i].mData.data(),
sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace(),
hipMemcpyHostToDevice));
hipGetErrorString(
hipMemcpy(c_device_buf_,
c_m_n_device_results[i].mData.data(),
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace(),
hipMemcpyHostToDevice));
// a_device_buf_.ToDevice(a_m_k[i].mData.data());
// b_device_buf_.ToDevice(b_k_n[i].mData.data());
// c_device_buf_.ToDevice(c_m_n_device_results[i].mData.data());
a_device_buf.push_back(a_device_buf_);
b_device_buf.push_back(b_device_buf_);
c_device_buf.push_back(c_device_buf_);
// a_device_buf.push_back(a_device_buf_);
// b_device_buf.push_back(b_device_buf_);
// c_device_buf.push_back(c_device_buf_);
// gemm_shapes.push_back({Ms[i],
// Ns[i],
// Ks[i],
// StrideAs[i],
// StrideBs[i],
// StrideCs[i],
// a_device_buf[i].GetDeviceBuffer(),
// b_device_buf[i].GetDeviceBuffer(),
// c_device_buf[i].GetDeviceBuffer()});
// printf("%p %p %p\n",
// a_device_buf[i].GetDeviceBuffer(),
// b_device_buf[i].GetDeviceBuffer(),
// c_device_buf[i].GetDeviceBuffer());
gemm_shapes.push_back({Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
StrideCs[i],
a_device_buf_,
b_device_buf_,
c_device_buf_});
// A_size += a_m_k[i].mDesc.GetElementSpace();
// B_size += b_k_n[i].mDesc.GetElementSpace();
// C_size += c_m_n_device_results[i].mDesc.GetElementSpace();
}
// a_device_buf_.ToDevice(a_tensors_data.data());
// b_device_buf_.ToDevice(b_tensors_data.data());
// c_device_buf_.ToDevice(c_tensors_data.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr> gemm_ptrs;
std::vector<
ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr>
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
......@@ -143,7 +238,6 @@ void profile_grouped_gemm_impl(int do_verification,
{
ck::tensor_operation::device::device_grouped_gemm_instance::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
}
#if 0
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
......@@ -216,24 +310,15 @@ void profile_grouped_gemm_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
#if 0
#if 1
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
gemm_ptr->MakeArgumentPointer(gemm_shapes,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
KBatch);
ck::tensor_operation::element_wise::PassThrough{});
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
......@@ -243,6 +328,7 @@ void profile_grouped_gemm_impl(int do_verification,
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
#if 0
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
......@@ -262,54 +348,36 @@ void profile_grouped_gemm_impl(int do_verification,
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
#endif
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
is_same<BDataType, ck::bhalf_t>::value &&
is_same<CDataType, ck::bhalf_t>::value)
{
Tensor<float> a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<float> b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<float> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<float> c_m_n_device_f32_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
// c_tensors_data.resize(C_size);
bf16_to_f32_(a_m_k, a_f32_m_k);
bf16_to_f32_(b_k_n, b_f32_k_n);
bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);
// c_device_buf_.FromDevice(c_tensors_data.data());
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<float, float, float, AElementOp, BElementOp, CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
// C_size = 0;
// for(int i = 0; i < gemm_shapes.size(); i++)
//{
// memcpy(c_m_n_device_results[i].mData.data(),
// c_tensors_data.data() + C_size,
// c_m_n_device_results[i].mDesc.GetElementSpace() * sizeof(CDataType));
auto ref_argument = ref_gemm.MakeArgument(a_f32_m_k,
b_f32_k_n,
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
// C_size += c_m_n_device_results[i].mDesc.GetElementSpace();
//}
ref_invoker.Run(ref_argument);
for(int i = 0; i < gemm_shapes.size(); i++)
{
hipGetErrorString(hipMemcpy(c_m_n_device_results[i].mData.data(),
c_device_buf[i],
sizeof(CDataType) *
c_m_n_device_results[i].mDesc.GetElementSpace(),
hipMemcpyDeviceToHost));
check_error(c_m_n_host_result, c_m_n_device_f32_result);
// hipGetErrorString(hipFree(c_device_buf[i]));
if(do_log)
{
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
}
}
else
{
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<ADataType,
......@@ -322,27 +390,30 @@ void profile_grouped_gemm_impl(int do_verification,
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
b_k_n[i],
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
check_error(c_m_n_host_result, c_m_n_device_result);
check_error(c_m_n_host_result, c_m_n_device_results[i]);
if(do_log)
{
// LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
//<< std::endl;
// LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") <<
// std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
<< std::endl;
// LogRangeAsType<float>(
// std::cout << "c_host : ", c_m_n_host_result.mData, ",")
//<< std::endl;
}
}
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
}
else
......
......@@ -26,7 +26,7 @@ enum GemmDataType
INT8_INT8_INT8, // 3
};
std::vector<int> stringToArray(char *input)
std::vector<int> stringToArray(char* input)
{
std::vector<int> out;
......@@ -34,7 +34,8 @@ std::vector<int> stringToArray(char *input)
std::string item;
while (std::getline(in, item, ',')) {
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
......@@ -69,30 +70,33 @@ int profile_grouped_gemm(int argc, char* argv[])
const auto Ms = stringToArray(argv[8]);
const auto Ns = stringToArray(argv[9]);
const auto Ks = stringToArray(argv[10]);
const auto StrideAs = stringToArray(argv[11]);
const auto StrideBs = stringToArray(argv[12]);
const auto StrideCs = stringToArray(argv[13]);
for(int i = 0; i < Ms.size(); i++)
{
std::cout << "M: " << Ms[i] << " N: " << Ns[i] << " K: " << Ks[i] << std::endl;
}
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs);
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
nrepeat,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs);
}
#if 0
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment