Commit 1a2a1679 authored by carlushuang's avatar carlushuang
Browse files

initial stream-k implementation with example

parent 8bb2bb4a
......@@ -44,3 +44,4 @@ if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
endif()
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
......@@ -39,7 +39,9 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>;
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
// clang-format on
using DeviceGemmInstance = DeviceGemmInstance1;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = ck::half_t;
using F16 = ck::half_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
// clang-format off
using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// // clang-format on
// clang-format on
using DeviceGemmInstance = DeviceGemmStreamK;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
......@@ -11,7 +11,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -25,12 +25,35 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
}
};
auto f_get_default_stride = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0) {
// give a chance if stride is zero, return a defalt packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 0:
ck::utils::FillConstant<ADataType>{1.f}(a_m_k);
ck::utils::FillConstant<BDataType>{1.f}(b_k_n);
break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
......
......@@ -94,6 +94,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp"
namespace ck {
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <typename ThreadGroup,
typename ElementwiseOperation,
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v6r1r2
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r1r2(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
element_op)
{
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.template Run<SrcBuffer, DstBuffer, DstInMemOp>(
src_desc, src_buf, dst_desc, dst_buf);
}
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v6r1r2<SrcData,
DstData,
SrcDesc,
DstDesc,
ElementwiseOperation,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
ck::index_t ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
ck::index_t BBlockLdsAddExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
struct DeviceGemmXdlStreamK : public DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk<
BlockSize,
BlockToCTileMap_GemmStreamK<MPerBlock, NPerBlock, K0PerBlock * K1>,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXDL,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
void Print(const Argument& karg) { karg.Print(); }
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
Print(karg);
}
if(!GridwiseGemm::CheckValidity(karg))
{
throw std::runtime_error(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
"setting");
}
dim3 grid_dims = karg.block_mapping.get_grid_dims();
float ave_time = 0;
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
// TODO: remove clear buffer for streamk kernels
hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
ave_time =
launch_and_time_kernel(stream_config, kernel, grid_dims, dim3(BlockSize), 0, karg);
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& karg)
{
return GridwiseGemm::CheckValidity(karg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
int occupancy, num_cu;
hipError_t rtn;
rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
hip_check_error(rtn);
hipDeviceProp_t dev_prop;
hipDevice_t dev;
rtn = hipGetDevice(&dev);
hip_check_error(rtn);
rtn = hipGetDeviceProperties(&dev_prop, dev);
hip_check_error(rtn);
num_cu = dev_prop.multiProcessorCount;
printf("XXX occupancy:%d, num_cu:%d, BLOCK_SIZE:%d, LDS:%d\n",
occupancy,
num_cu,
BlockSize,
GridwiseGemm::GetSharedMemoryNumberOfByte());
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
static_cast<uint32_t>(num_cu),
static_cast<uint32_t>(occupancy)};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
{
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
int occupancy, num_cu;
hipError_t rtn;
rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
hip_check_error(rtn);
hipDeviceProp_t dev_prop;
hipDevice_t dev;
rtn = hipGetDevice(&dev);
hip_check_error(rtn);
rtn = hipGetDeviceProperties(&dev_prop, dev);
hip_check_error(rtn);
num_cu = dev_prop.multiProcessorCount;
return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
reinterpret_cast<const BDataType*>(p_b),
reinterpret_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
static_cast<uint32_t>(num_cu),
static_cast<uint32_t>(occupancy));
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,6 +7,7 @@
#include "ck/utility/number.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include <limits>
namespace ck {
......@@ -635,4 +636,236 @@ struct BlockToCTileMap_3DGrid_KSplit
}
};
template <uint32_t MPerBlock_, uint32_t NPerBlock_, uint32_t KPerBlock_>
struct BlockToCTileMap_GemmStreamK
{
static constexpr uint32_t min_k_iters_per_sk_block = 2;
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
//--------------------------------------
// pass to device
uint32_t sk_num_blocks;
uint32_t sk_num_big_blocks;
uint32_t sk_total_iters;
uint32_t dp_start_block_idx;
uint32_t dp_iters_per_block;
uint32_t dp_num_blocks;
uint32_t k_iters_per_big_block;
MDiv k_iters_per_tile;
MDiv n_tiles;
//--------------------------------------
// prefer construct on host
BlockToCTileMap_GemmStreamK(
uint32_t m, uint32_t n, uint32_t k, uint32_t num_cu, uint32_t occupancy)
{
uint32_t num_tiles =
math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));
// one cu can hold one wg at one time, from the whole chip's point of view
// if number of wg is same as num_cu, we call it 1 dispatch
// if number of wg is 2x num_cu, we call it 2 dispatches.
// one dispatch can deliever wg same as num_cu (full dispatch), or less than num_cu (partial
// dispatch)
//
uint32_t full_dispatches = num_tiles / num_cu;
uint32_t full_dispatch_tiles = full_dispatches * num_cu;
uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;
uint32_t sk_occupancy = occupancy;
uint32_t dp_tiles = full_dispatch_tiles;
uint32_t sk_tiles = partial_dispatche_tiles;
if(full_dispatches < occupancy)
{
// in this case, we allocate all blocks as sk blocks
// sk_occupancy = occupancy - full_dispatches;
sk_occupancy = 1; // TODO: single occ seems better
dp_tiles = full_dispatch_tiles;
sk_tiles = partial_dispatche_tiles;
}
else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
{
// e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
// occupancy = 3, full_dispatches = 5, 8, 11 ...
// occupancy = 4, full_dispatches = 7, 11 ...
sk_occupancy = 1; // left 1 slot for sk occupancy
dp_tiles = full_dispatch_tiles;
sk_tiles = partial_dispatche_tiles;
}
else
{
// others, we reduce 1 dispatch from dp, together with partial dispatch,
// to construct sk dispatch
sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
dp_tiles = full_dispatch_tiles - num_cu;
sk_tiles = partial_dispatche_tiles + num_cu;
}
// dp_num_blocks = dp_tiles;
// dp_start_block_idx = num_cu * sk_occupancy;
dp_iters_per_block = k_iters_per_tile.get();
sk_total_iters = k_iters_per_tile.get() * sk_tiles;
// printf("num_tiles:%d, full_dispatches:%d, full_dispatch_tiles:%d,
// partial_dispatche_tiles:%d\n",
// num_tiles, full_dispatches, full_dispatch_tiles, partial_dispatche_tiles);
{
uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
uint32_t max_sk_tiles =
(sk_tiles >= num_cu) ? num_cu * sk_occupancy
: math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
// if use dp for sk-block, how many iters do we need
uint32_t dp_for_sk_iters = k_iters_per_tile.get();
uint32_t best_sk_score =
std::numeric_limits<int>::max(); // we need to find the smallest sk iters
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
tentative_sk_blocks++)
{
uint32_t tentative_sk_iters_per_block =
(sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
// TODO: carefully adjust this parameter
// the more sk_blocks_per_tile, the worse the overhead
uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
if(tentative_sk_blocks % sk_tiles != 0)
{
// penalty for uneven divide
cross_sk_blocks_overhead +=
sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
}
uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
if(tentative_sk_score < best_sk_score)
{
best_sk_score = tentative_sk_score;
sk_num_blocks = tentative_sk_blocks;
}
}
if(best_sk_score >= dp_for_sk_iters)
{
sk_num_blocks = 0;
}
if(sk_num_blocks == 0)
{
sk_num_big_blocks = 0;
k_iters_per_big_block = 0;
dp_num_blocks = num_tiles; // all tile to be dp block
dp_start_block_idx = 0;
sk_total_iters = 0; // clear this tiles
}
else
{
// k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
// we need to decide how many iters for each sk block
// let m = k_iters_per_sk_block
// some of the sk block (little) will cover m iters, some (big) will cover m+1
// we have
// 1) l + b = sk_blocks
// 2) l * m + b * (m + 1) = sk_total_iters
// => (l + b) * m + b = sk_total_iters
// => sk_blocks * m + b = sk_total_iters
// => b = sk_total_iters - m * sk_blocks
// NOTE: big could be zero
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
k_iters_per_big_block = k_iters_per_sk_block + 1;
dp_num_blocks = dp_tiles;
dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
}
}
n_tiles = MDiv(math::integer_divide_ceil(n, NPerBlock));
printf("cu:%d, occupancy:%d, grids:%d, sk_num_big_blocks:%d, sk_num_blocks:%d, "
"sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
"k_iters_per_tile:%d, k_iters_per_big_block:%d\n",
num_cu,
occupancy,
get_grid_dims().x,
sk_num_big_blocks,
sk_num_blocks,
sk_total_iters,
dp_start_block_idx,
dp_iters_per_block,
dp_num_blocks,
k_iters_per_tile.get(),
k_iters_per_big_block);
}
__host__ __device__ dim3 get_grid_dims() const
{
return dim3(dp_start_block_idx + dp_num_blocks, 1, 1);
}
__device__ uint32_t get_block_idx() const
{
// TODO: swizzle block index for better locality
return __builtin_amdgcn_readfirstlane(blockIdx.x);
}
__device__ void
get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
{
if(block_idx < sk_num_big_blocks)
{
iter_start = block_idx * k_iters_per_big_block;
iter_end = iter_start + k_iters_per_big_block;
}
else if(block_idx < sk_num_blocks)
{
iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
iter_end = iter_start + (k_iters_per_big_block - 1);
}
else if(block_idx >= dp_start_block_idx)
{
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
iter_end = iter_start + dp_iters_per_block;
}
}
__device__ uint32_t get_current_iter_length(uint32_t iter_start,
uint32_t iter_end,
uint32_t total_iter_length) const
{
uint32_t iter_length_mod, iter_length_quo /*unused*/;
k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
uint32_t current_iter_length = math::min(
iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
return current_iter_length;
}
__device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }
__device__ void
get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
{
k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
}
__device__ auto tile_to_spatial(uint32_t tile_idx) const
{
// TODO:
uint32_t m_tile_idx, n_tile_idx;
n_tiles.divmod(tile_idx, m_tile_idx, n_tile_idx);
return make_tuple(m_tile_idx, n_tile_idx);
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
struct GridwiseGemmPipeline_v3
{
__host__ __device__ static constexpr bool IsSupported(index_t)
{
// TODO: improve applicability
return true;
}
template <typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// LDS write 0
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
num_loop--;
while(num_loop > 0)
{
block_sync_lds();
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
num_loop--;
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp"
namespace ck {
template <typename GridwiseGemm>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_streamk(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
GridwiseGemm::Run(karg, static_cast<void*>(p_shared));
#else
ignore = karg;
ignore = b2c_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <index_t BlockSize,
typename Block2CTileMap_,
typename FloatAB,
typename FloatAcc,
typename FloatC,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
static constexpr auto M01 = 1;
static constexpr auto N01 = 1;
static constexpr auto KPerBlock = K0PerBlock * K1;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using Block2CTileMap = Block2CTileMap_;
struct Argument : public ck::tensor_operation::device::BaseArgument
{
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
Block2CTileMap block_mapping;
Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
uint32_t num_cu,
uint32_t occupancy)
: p_a_grid(p_a_grid_),
p_b_grid(p_b_grid_),
p_c_grid(p_c_grid_),
M(M_),
N(N_),
K(K_),
StrideA(StrideA_),
StrideB(StrideB_),
StrideC(StrideC_),
block_mapping(M, N, K, num_cu, occupancy)
{
}
void Print() const
{
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << std::endl;
}
};
__host__ __device__ static auto CalculateGridSize(const Argument& karg)
{
return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
math::integer_divide_ceil(karg.M, MPerBlock),
karg.k_batch);
}
#if 0
// prefer this to be called on host
__host__ __device__ static auto CalculateMPadded(index_t M)
{
return (M + MPerBlock - 1) / MPerBlock * MPerBlock;
}
__host__ __device__ static auto CalculateNPadded(index_t N)
{
return (N + NPerBlock - 1) / NPerBlock * NPerBlock;
}
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
{
// k_batch * k0 * k0_per_block * k1
auto K_t = K_Batch * K0PerBlock * K1;
return (K + K_t - 1) / K_t * K0PerBlock;
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K0 = CalculateK0(K, K_Batch);
return K_Batch * K0 * K1;
}
#endif
__host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; }
__host__ __device__ static auto
MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
{
const index_t K0 = CalculateK0(KPad);
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
__host__ __device__ static auto
MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
{
const index_t K0 = CalculateK0(KPad);
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
return transform_tensor_descriptor(c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * K1, K1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * K1, K1, I1));
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = K1;
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatC));
}
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
{
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
return false;
}
else
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
return false;
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
return false;
}
else
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
return false;
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
return false;
}
else
{
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
return false;
}
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = K0 > K0PerBlock;
return has_main_k0_block_loop;
}
template <typename CGridDesc>
__host__ __device__ static constexpr auto
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
return transform_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
// return block_id to C matrix tile idx (m0, n0) mapping
template <typename CGridDesc>
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
{
return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc>(
c_m_n_grid_desc, 8, KBatch);
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
return make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMRepeatPerShuffle * MWave * MPerXDL>{},
I1,
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
}
// return block_id to C matrix tile idx (m0, n0, k_split) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
{
return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
}
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
__device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
{
const FloatAB* p_a_grid = karg.p_a_grid;
const FloatAB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid;
uint32_t m = karg.M;
uint32_t n = karg.N;
uint32_t k = karg.K;
uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock;
uint32_t stride_a = karg.StrideA;
uint32_t stride_b = karg.StrideB;
uint32_t stride_c = karg.StrideC;
const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
const AElementwiseOperation a_element_op = AElementwiseOperation{};
const BElementwiseOperation b_element_op = BElementwiseOperation{};
const CElementwiseOperation c_element_op = CElementwiseOperation{};
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
K1>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();
auto& block_mapping = karg.block_mapping;
uint32_t block_idx = block_mapping.get_block_idx();
bool is_sk_block = block_idx < block_mapping.sk_num_blocks;
bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx;
uint32_t iter_start, iter_end;
block_mapping.get_block_itr(block_idx, iter_start, iter_end);
uint32_t total_iter_length = iter_end - iter_start;
// if(threadIdx.x == 0)
// printf("xxx bid:%d\n", static_cast<int>(blockIdx.x));
if(!is_sk_block && !is_dp_block)
return;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
uint32_t tile_idx, iter_offset;
block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
auto spatial_idx = block_mapping.tile_to_spatial(tile_idx);
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock);
const index_t k0_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
// if(threadIdx.x == 0)
// printf("[%s], bid:%d, block_idx:%d, tile_idx:%d(%d, %d, %d), iter_start:%d(%d |
// %d), iter_end:%d, len:%d\n",
// is_sk_block ? "sk_block" : (is_dp_block ? "dp_block" : "other "),
// static_cast<int>(blockIdx.x), block_idx, tile_idx, m_block_data_idx_on_grid,
// n_block_data_idx_on_grid, k0_block_data_idx_on_grid, iter_end -
// current_iter_length, iter_offset, iter_start, iter_end, current_iter_length);
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m_k1_grid_desc),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k0_m_k1_grid_desc,
make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n_k1_grid_desc),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k0_n_k1_grid_desc,
make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
const index_t num_k_block_main_loop = current_iter_length;
gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
a_block_desc_k0_m_k1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_k0_n_k1_grid_desc,
b_block_desc_k0_n_k1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// output: register to global memory
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared_block),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_freeze_transform(I0), // freeze mblock
make_unmerge_transform(
make_tuple(CShuffleMRepeatPerShuffle,
M1,
M2,
M3,
M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
make_freeze_transform(I0), // freeze nblock
make_unmerge_transform(
make_tuple(CShuffleNRepeatPerShuffle,
N1,
N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{},
Sequence<0, 2, 4, 5, 6>{},
Sequence<>{},
Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// LDS to global
auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerXDL,
1,
CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatC, // typename SrcData,
FloatC, // typename DstData,
decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun
{c_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
0,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
0),
c_element_op};
constexpr auto mxdlperwave_forward_step =
make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
constexpr auto nxdlperwave_forward_step =
make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
constexpr auto nxdlperwave_backward_step =
make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);
static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
constexpr auto mxdlperwave = mxdlperwave_iter;
static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
constexpr bool nxdlperwave_forward_sweep =
(mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
constexpr index_t nxdlperwave_value =
nxdlperwave_forward_sweep
? nxdlperwave_iter
: (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
// make sure it's safe to do ds_write
block_sync_lds();
// VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_block_buf);
// make sure it's safe to do ds_read
block_sync_lds();
// LDS to global
if(is_dp_block)
c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::Set>(
c_block_desc_mblock_mperblock_nblock_nperblock,
c_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
else if(is_sk_block)
c_block_copy_lds_to_global
.template Run<decltype(c_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_block_desc_mblock_mperblock_nblock_nperblock,
c_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
// move on nxdlperwave dimension
if constexpr(nxdlperwave_forward_sweep &&
(nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
nxdlperwave_forward_step);
}
else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
nxdlperwave_backward_step);
}
});
// move on mxdlperwave dimension
if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
mxdlperwave_forward_step);
}
});
}
// exit condition
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
// make sure next loop LDS is ready for use
block_sync_lds();
}
// if(threadIdx.x == 0)
// printf("xxx bid:%d, xx_total_iter_length:%d \n", static_cast<int>(blockIdx.x),
// xx_total_iter_length);
}
template <typename Layout>
struct LStr
{
static std::string Get() { return ""; }
};
template <>
struct LStr<ck::tensor_layout::gemm::RowMajor>
{
static std::string Get() { return "R"; }
};
template <>
struct LStr<ck::tensor_layout::gemm::ColumnMajor>
{
static std::string Get() { return "C"; }
};
static std::string GetTypeString()
{
auto str = std::stringstream();
// clang-format off
str << "GemmXdlStreamK_"
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< "_"
<< "B" << BlockSize << "_"
<< "Vec" << ABlockTransferSrcScalarPerVector << "x"
<< BBlockTransferSrcScalarPerVector << "x"
<< CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
<< MPerBlock << "x"
<< NPerBlock << "x"
<< K0PerBlock << "x"
<< K1 ;
// clang-format on
return str.str();
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
// instead
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool SrcResetCoordinateAfterRun,
bool DstResetCoordinateAfterRun>
struct ThreadwiseTensorSliceTransfer_v6r1r2
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const ElementwiseOperation& element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
element_op_(element_op)
{
static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
"wrong! cannot evenly divide");
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(scalar_per_access)>>;
// loop over space-filling curve
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src_vector_type = vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
using dst_vector_type = vector_type_maker_t<DstData, ScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
auto dst_vector_container = dst_vector_type{};
// apply pointwise operation
static_for<0, ScalarPerVector, 1>{}([&](auto i) {
SrcData v;
// apply element-wise operation
element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
// apply type convert
dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
});
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
// move coordinate
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
// move coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}
__device__ static constexpr auto GetCoordinateResetStep()
{
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = SrcResetCoordinateAfterRun
? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx = DstResetCoordinateAfterRun
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
private:
SrcCoord src_coord_;
DstCoord dst_coord_;
const ElementwiseOperation element_op_;
};
} // namespace ck
......@@ -157,4 +157,42 @@ struct MagicDivision
}
};
struct MDiv
{
// 1 dword -> 3 dword storage
uint32_t divisor;
uint32_t multiplier;
uint32_t shift;
// prefer construct on host
__host__ __device__ MDiv(uint32_t divisor_) : divisor(divisor_)
{
ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
}
__host__ __device__ MDiv() : divisor(0), multiplier(0), shift(0) {}
__host__ __device__ void update(uint32_t divisor_)
{
divisor = divisor_;
ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
}
__host__ __device__ uint32_t div(uint32_t dividend) const
{
return MagicDivision::DoMagicDivision(dividend, multiplier, shift);
}
__host__ __device__ void
divmod(uint32_t dividend, uint32_t& quotient, uint32_t& remainder) const
{
quotient = div(dividend);
remainder = dividend - (quotient * divisor);
}
__host__ __device__ uint32_t operator/(uint32_t dividend) const { return div(dividend); }
__host__ __device__ uint32_t get() const { return divisor; }
};
} // namespace ck
......@@ -55,7 +55,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count > 0)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......@@ -102,7 +102,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count > 0)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......@@ -148,7 +148,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count > 0)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......@@ -199,7 +199,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count > 0)
{
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
<< std::endl;
......
#include <stdio.h>
#include <string>
#include <algorithm>
#include <vector>
#include <limits>
#include "simple_args.h"
simple_args_t create_arg(int argc, char** argv)
{
simple_args_t args;
args.insert("m", "1024", "matrix m")
.insert("n", "1024", "matrix n")
.insert("k", "1024", "matrix k")
.insert("m_per_block", "128", "m_per_block")
.insert("n_per_block", "128", "n_per_block")
.insert("k_per_block", "32", "k_per_block")
.insert("num_cu", "104", "num cu")
.insert("occupancy", "2", "occupancy")
.parse(argc, argv);
return args;
}
namespace impl {
template <typename T>
T integer_divide_ceil(T n, T d)
{
return (n + d - 1) / d;
}
template <typename T>
T min(T a, T b)
{
return a > b ? b : a;
}
template <typename T>
T max(T a, T b)
{
return a > b ? a : b;
}
} // namespace impl
struct block_dispatcher_t
{
public:
uint32_t m_per_block;
uint32_t n_per_block;
uint32_t k_per_block;
uint32_t num_cu;
uint32_t occupancy;
uint32_t m;
uint32_t n;
uint32_t k;
//--------------------------------------
uint32_t sk_num_blocks;
uint32_t sk_num_big_blocks;
uint32_t sk_total_iters;
// uint32_t sk_num_blocks_per_tile; // how many
uint32_t dp_start_block_idx;
uint32_t dp_iters_per_block;
uint32_t dp_num_blocks;
uint32_t k_iters_per_tile;
uint32_t k_iters_per_big_block;
//--------------------------------------
static constexpr uint32_t min_k_iters_per_sk_block = 1;
void dump()
{
printf("%dx%dx%d(%dx%dx%d), cu:%d, occ:%d, grids:%d, sk_num_big_blocks:%d, "
"sk_num_blocks:%d, sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, "
"dp_num_blocks:%d, k_iters_per_tile:%d, k_iters_per_big_block:%d\n",
m,
n,
k,
m_per_block,
n_per_block,
k_per_block,
num_cu,
occupancy,
get_grid_dims_x(),
sk_num_big_blocks,
sk_num_blocks,
sk_total_iters,
dp_start_block_idx,
dp_iters_per_block,
dp_num_blocks,
k_iters_per_tile,
k_iters_per_big_block);
}
block_dispatcher_t(uint32_t m_per_block_,
uint32_t n_per_block_,
uint32_t k_per_block_,
uint32_t num_cu_,
uint32_t occupancy_,
uint32_t m_,
uint32_t n_,
uint32_t k_)
: m_per_block(m_per_block_),
n_per_block(n_per_block_),
k_per_block(k_per_block_),
num_cu(num_cu_),
occupancy(occupancy_),
m(m_),
n(n_),
k(k_)
{
init();
}
uint32_t get_grid_dims_x() { return dp_start_block_idx + dp_num_blocks; }
uint32_t get_block_idx(uint32_t bid)
{
// block id is linearily allocated along sk blocks (dp blocks are fine)
// this function will compute blockIdx.x and the linear sk block mapping
// uint32_t block_idx = 0;
// if(bid < sk_num_big_blocks) {
// uint32_t current_k_iter = bid * k_iters_per_big_block;
// tile_idx = current_k_iter / k_iters_per_tile;
// }
return bid;
}
uint32_t get_current_itr(uint32_t block_idx)
{
uint32_t current_itr = 0;
if(block_idx < sk_num_big_blocks)
{
current_itr = block_idx * k_iters_per_big_block;
}
else if(block_idx < sk_num_blocks)
{
current_itr = (sk_num_big_blocks * k_iters_per_big_block) +
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
}
else if(block_idx >= dp_start_block_idx)
{
current_itr = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
}
return current_itr;
}
void get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end)
{
if(block_idx < sk_num_big_blocks)
{
iter_start = block_idx * k_iters_per_big_block;
iter_end = iter_start + k_iters_per_big_block;
}
else if(block_idx < sk_num_blocks)
{
iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
iter_end = iter_start + (k_iters_per_big_block - 1);
}
else if(block_idx >= dp_start_block_idx)
{
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
iter_end = iter_start + dp_iters_per_block;
}
}
private:
void init()
{
uint32_t num_tiles =
impl::integer_divide_ceil(m, m_per_block) * impl::integer_divide_ceil(n, n_per_block);
k_iters_per_tile = impl::integer_divide_ceil(k, k_per_block);
// one cu can hold one wg at one time, from the whole chip's point of view
// if number of wg is same as num_cu, we call it 1 dispatch
// if number of wg is 2x num_cu, we call it 2 dispatches.
// one dispatch can deliever wg same as num_cu (full dispatch), or less than num_cu (partial
// dispatch)
//
uint32_t full_dispatches = num_tiles / num_cu;
uint32_t full_dispatch_tiles = full_dispatches * num_cu;
uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;
uint32_t sk_occupancy = occupancy;
uint32_t dp_tiles = full_dispatch_tiles;
uint32_t sk_tiles = partial_dispatche_tiles;
if(full_dispatches < occupancy)
{
// in this case, we allocate all blocks as sk blocks
// sk_occupancy = occupancy - full_dispatches;
sk_occupancy = 1; // TODO: single occ seems better
dp_tiles = full_dispatch_tiles;
sk_tiles = partial_dispatche_tiles;
}
else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
{
// e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
// occupancy = 3, full_dispatches = 5, 8, 11 ...
// occupancy = 4, full_dispatches = 7, 11 ...
sk_occupancy = 1; // left 1 slot for sk occupancy
dp_tiles = full_dispatch_tiles;
sk_tiles = partial_dispatche_tiles;
}
else
{
// others, we reduce 1 dispatch from dp, together with partial dispatch,
// to construct sk dispatch
sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
dp_tiles = full_dispatch_tiles - num_cu;
sk_tiles = partial_dispatche_tiles + num_cu;
}
// dp_num_blocks = dp_tiles;
// dp_start_block_idx = num_cu * sk_occupancy;
dp_iters_per_block = k_iters_per_tile;
sk_total_iters = k_iters_per_tile * sk_tiles;
// printf("num_tiles:%d, full_dispatches:%d, full_dispatch_tiles:%d,
// partial_dispatche_tiles:%d\n",
// num_tiles, full_dispatches, full_dispatch_tiles, partial_dispatche_tiles);
{
uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
uint32_t max_sk_tiles =
(sk_tiles >= num_cu) ? num_cu * sk_occupancy
: impl::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
// if use dp for sk-block, how many iters do we need
uint32_t dp_for_sk_iters = k_iters_per_tile;
uint32_t best_sk_score =
std::numeric_limits<int>::max(); // we need to find the smallest sk iters
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
tentative_sk_blocks++)
{
uint32_t tentative_sk_iters_per_block =
(sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
// TODO: carefully adjust this parameter
// the more sk_blocks_per_tile, the worse the overhead
uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
if(tentative_sk_blocks % sk_tiles != 0)
{
// penalty for uneven divide
cross_sk_blocks_overhead +=
sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
}
uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
if(tentative_sk_score < best_sk_score)
{
best_sk_score = tentative_sk_score;
sk_num_blocks = tentative_sk_blocks;
}
}
if(best_sk_score >= dp_for_sk_iters)
{
sk_num_blocks = 0;
}
if(sk_num_blocks == 0)
{
sk_num_big_blocks = 0;
k_iters_per_big_block = 0;
dp_num_blocks = num_tiles; // all tile to be dp block
dp_start_block_idx = 0;
sk_total_iters = 0; // clear this tiles
}
else
{
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
k_iters_per_big_block = k_iters_per_sk_block + 1;
dp_num_blocks = dp_tiles;
dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
}
}
}
};
struct tile_work_t
{
uint32_t tile_idx;
uint32_t iter_begin;
uint32_t k_begin;
uint32_t k_end;
uint32_t k_iters_remaining;
};
int main(int argc, char** argv)
{
simple_args_t arg = create_arg(argc, argv);
block_dispatcher_t block_dispatcher{arg.get_uint32("m_per_block"),
arg.get_uint32("n_per_block"),
arg.get_uint32("k_per_block"),
arg.get_uint32("num_cu"),
arg.get_uint32("occupancy"),
arg.get_uint32("m"),
arg.get_uint32("n"),
arg.get_uint32("k")};
block_dispatcher.dump();
// simulate actual kernel launch
uint32_t dim_x = block_dispatcher.get_grid_dims_x();
uint32_t total_k_iters =
impl::integer_divide_ceil(arg.get_uint32("k"), arg.get_uint32("k_per_block"));
uint32_t num_tiles =
impl::integer_divide_ceil(arg.get_uint32("m"), arg.get_uint32("m_per_block")) *
impl::integer_divide_ceil(arg.get_uint32("n"), arg.get_uint32("n_per_block"));
std::vector<int> valid_tile_record(num_tiles * total_k_iters);
for(uint32_t bid = 0; bid < dim_x; bid++)
{
uint32_t block_idx = block_dispatcher.get_block_idx(bid);
bool is_sk_block = block_idx < (block_dispatcher.sk_num_blocks);
bool is_dp_block = block_idx >= block_dispatcher.dp_start_block_idx;
uint32_t iter_start, iter_end;
block_dispatcher.get_block_itr(block_idx, iter_start, iter_end);
uint32_t total_iter_length = iter_end - iter_start;
while(true)
{
uint32_t iter_length_mod = iter_end % block_dispatcher.k_iters_per_tile;
uint32_t current_iter_length =
impl::min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod,
total_iter_length);
uint32_t tile_idx = (iter_end - 1) / block_dispatcher.k_iters_per_tile;
uint32_t tile_iter_start =
((iter_end - 1) % block_dispatcher.k_iters_per_tile) - current_iter_length + 1;
if(is_sk_block)
{
printf("[sk_block] bid:%3d, block_idx:%3d, tile_idx:%3d, iter_start:%d(%d | %d), "
"iter_end:%d (len:%d)\n",
bid,
block_idx,
tile_idx,
iter_end - current_iter_length,
tile_iter_start,
iter_start,
iter_end,
current_iter_length);
}
else if(is_dp_block)
{
printf("[dp_block] bid:%3d, block_idx:%3d, tile_idx:%3d, iter_start:%d(%d | %d), "
"iter_end:%d (len:%d)\n",
bid,
block_idx,
tile_idx,
iter_end - current_iter_length,
tile_iter_start,
iter_start,
iter_end,
current_iter_length);
}
else
{
printf("[other ] bid:%3d, block_idx:%3d\n", bid, block_idx);
}
// some validation check
for(auto i = iter_end - current_iter_length; i < iter_end; i++)
{
if(i >= valid_tile_record.size())
{
printf("unexpected, current iter:%d larger than max:%d\n",
i,
valid_tile_record.size());
return -1;
}
valid_tile_record[i] = 1;
}
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
}
}
int untouched = 0;
for(auto i = 0; i < valid_tile_record.size(); i++)
{
if(valid_tile_record[i] != 1)
{
printf("untouched at %d (%d)\n", i, valid_tile_record.size());
untouched++;
}
}
printf("untouched %d/%d, %s\n",
untouched,
valid_tile_record.size(),
untouched == 0 ? "valid" : "fail");
}
CC=g++
$CC -Wall -std=c++17 -Iinclude -O3 block_swizzle_test.cpp -o block_swizzle_test.exe
\ No newline at end of file
#pragma once
#include <iomanip>
#include <iostream>
#include <stdlib.h>
#include <string>
#include <unordered_map>
#include <vector>
#include <assert.h>
struct arg_content_t
{
std::string name; // key
std::string value;
std::string help_text;
};
class simple_args_t
{
public:
simple_args_t() {}
simple_args_t& insert(const std::string& name_,
const std::string& default_value_,
const std::string& help_text_)
{
arg_content_t arg{name_, default_value_, help_text_};
if(arg_map.count(arg.name) != 0)
{
std::cout << "arg:" << arg.name << "already exist" << std::endl;
}
else
{
arg_map[arg.name] = arg;
}
return *this;
}
void usage()
{
for(auto& content : arg_map)
{
std::vector<std::string> help_text_lines;
size_t pos = 0;
for(size_t next_pos = content.second.help_text.find('\n', pos);
next_pos != std::string::npos;)
{
help_text_lines.push_back(
std::string(content.second.help_text.begin() + pos,
content.second.help_text.begin() + next_pos++));
pos = next_pos;
next_pos = content.second.help_text.find('\n', pos);
}
help_text_lines.push_back(std::string(content.second.help_text.begin() + pos,
content.second.help_text.end()));
int arg_name_width = 16 - content.second.name.length();
arg_name_width = arg_name_width > 0 ? arg_name_width : 2;
std::cout << std::setw(4) << "-" << content.second.name << std::setw(arg_name_width)
<< " " << help_text_lines[0] << std::endl;
for(auto help_next_line = std::next(help_text_lines.begin());
help_next_line != help_text_lines.end();
++help_next_line)
{
std::cout << std::setw(28) << " " << *help_next_line << std::endl;
}
}
}
bool parse(int argc, char* argv[], int start_index = 1)
{
if(argc <= start_index)
{
// std::cout << "not enough args (" << argc << ") with starting index " << start_index
// << std::endl;
return true;
}
for(int i = start_index; i < argc; i++)
{
std::string cur_arg = std::string(argv[i]);
if(cur_arg[0] != '-')
{
std::cout << "illegal input" << std::endl;
usage();
return false;
}
else if(cur_arg[0] == '-' && cur_arg[1] == '?')
{
usage();
return false;
}
else
{
size_t found_equal = cur_arg.find('=');
if(found_equal == std::string::npos || found_equal == (cur_arg.length() - 1))
{
std::cout << "failed while parsing \"" << cur_arg << "\", "
<< "arg must be in the form \"-name=value\"" << std::endl;
return false;
}
std::string arg_name = cur_arg.substr(1, found_equal - 1);
std::string arg_value = cur_arg.substr(found_equal + 1);
if(arg_map.count(arg_name) == 0)
{
std::cout << "no such arg \"" << arg_name << "\" registered" << std::endl;
return false;
}
arg_map[arg_name].value = arg_value;
}
}
return true;
}
std::string get(const std::string& name) const { return get_str(name); }
std::string get_str(const std::string& name) const
{
assert(arg_map.count(name) != 0);
std::string value = arg_map.at(name).value;
return value;
}
int get_int(const std::string& name) const
{
assert(arg_map.count(name) != 0);
int value = atoi(arg_map.at(name).value.c_str());
return value;
}
uint32_t get_uint32(const std::string& name) const
{
assert(arg_map.count(name) != 0);
uint32_t value = strtoul(arg_map.at(name).value.c_str(), nullptr, 10);
return value;
}
uint64_t get_uint64(const std::string& name) const
{
assert(arg_map.count(name) != 0);
uint64_t value = strtoull(arg_map.at(name).value.c_str(), nullptr, 10);
return value;
}
double get_double(const std::string& name) const
{
assert(arg_map.count(name) != 0);
double value = atof(arg_map.at(name).value.c_str());
return value;
}
float get_float(const std::string& name) const
{
assert(arg_map.count(name) != 0);
float value = atof(arg_map.at(name).value.c_str());
return value;
}
private:
std::unordered_map<std::string, arg_content_t> arg_map;
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment