Commit 522b7aee authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

parents ff936fd6 84832fc4
......@@ -35,15 +35,17 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
if(lengths.size() != NumDim1 + NumDim2)
{
std::ostringstream err;
err << "Incorrect number of lengths in " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
err << "Incorrect number of lengths in "
<< "device_contraction_utils.hpp"
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(strides.size() != NumDim1 + NumDim2)
{
std::ostringstream err;
err << "Incorrect number of strides in " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
err << "Incorrect number of strides in "
<< "device_contraction_utils.hpp"
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
// version currently has compiler issues with register spill which further causes validation
// failures.
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceGemm_Xdl_CShuffleV2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v2<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer,
ComputeTypeA,
ComputeTypeB>;
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
float ave_time = 0;
const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
if(GridwiseGemm::CalculateKBlockLoopTailNum(K) == 3)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true, 2>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceGemm_Xdl_CShuffleV2"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -174,6 +174,11 @@ struct PassThrough
{
y = x;
}
template <>
__host__ __device__ void operator()<int4_t, int>(int4_t& y, const int& x) const
{
y = type_convert<int4_t>(x);
}
#endif
template <>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -134,6 +134,11 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{
#if 0
if(get_thread_global_1d_id()==0){
printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
}
#endif
}
template <typename CGridDesc_M_N>
......@@ -252,6 +257,302 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
BlockToCTileMap_M00_N0_M01Adapt;
};
// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt;
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(
const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
__host__ __device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
index_t N,
index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{
#if 0
if(get_thread_global_1d_id()==0){
printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
}
#endif
}
template <typename CGridDesc_M_N>
__host__ __device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
: BlockToCTileMap_Grouped_M00_N0_M01Adapt(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{
}
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return M0 * N0;
}
template <typename CGridDesc_M_N>
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum);
auto group_id = block_1d_id % GroupNum;
auto remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum;
index_t idx_N0 = remap_block_1d_id % N0;
index_t idx_M0 = remap_block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
private:
index_t M_;
index_t N_;
index_t M01_;
};
// keep the redundant type argument for backward compatibility
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
: BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
{
using BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>::
BlockToCTileMap_Grouped_M00_N0_M01Adapt;
};
// columns of row-vectors
// This C-tile map dynamically adjusts N01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_N00_M0_N01Adapt;
template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default;
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) =
default;
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) =
default;
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default;
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8)
: M_(M), N_(N), N01_(N01)
{
#if 0
if(get_thread_global_1d_id()==0){
printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_);
}
#endif
}
template <typename CGridDesc_M_N>
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t N01 = 8)
: BlockToCTileMap_N00_M0_N01Adapt(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01)
{
}
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return M0 * N0;
}
template <typename CGridDesc_M_N>
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
index_t idx_M0 = block_1d_id % M0;
index_t idx_N0 = block_1d_id / M0;
const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_;
index_t idx_N00 = idx_N0 / N01_;
index_t idx_N01 = idx_N0 % N01_;
index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0;
/**
* idxN0
*
* |< mtx N >|
*
* |<---N01--->|
* - |-----------|-----------|-----------|-----|-----|-
* ^ | 0 ----------> 1 | | | |
* | | / | | | | M_0 MPerBlock
* | / | | | |
* |------/----------------|-----------|-----|-----|-
* | | | | | | |
* idxM0 | V | | | | | M_1 MPerBlock
* | 2 ----------> 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | blockid | | | |
* | | 5 | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* Example:
* assume:
* N0 = 5
* M0 = 4
* block_1d_id = 5
* N01 = 2
*
* idx_M0 = 1
* idx_N0 = 1
* N01_adapt = 2
* idx_N00 = 0
* idx_N01 = 1
* idx_M0_N01_local = 5
* output {2, 1}
*/
return make_tuple(idx_M0_N01_local / N01_adapt,
idx_M0_N01_local % N01_adapt + idx_N00 * N01_);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
private:
index_t M_;
index_t N_;
index_t N01_;
};
// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
......
......@@ -119,7 +119,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const InDataTypePointerTuple p_in_global_tuple,
XDataType* const __restrict__ p_x_lds,
XDataType* const __restrict__ p_x_lds_,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
......@@ -149,7 +149,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto x_lds_val_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_x_lds, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);
p_x_lds_, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);
auto in_thread_buf_tuple = generate_tuple(
[&](auto) {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseGemm, bool HasMainKBlockLoop, index_t TailNum = 3>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop, TailNum>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
#endif
kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid,
const FloatB* p_b_grid,
FloatC* p_c_grid,
typename GridwiseGemm::Problem problem)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = problem;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename ALayout,
typename BLayout,
typename CLayout,
typename FloatA,
typename FloatB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeTypeA = FloatC,
typename ComputeTypeB = ComputeTypeA>
struct GridwiseGemm_xdl_cshuffle_v2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
}
__host__ static auto CalculateMPadded(index_t M)
{
return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
}
__host__ static auto CalculateNPadded(index_t N)
{
return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
}
__host__ static auto CalculateKPadded(index_t K)
{
return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
}
__host__ static auto CalculateAK0(index_t K)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
return CalculateKPadded(K) / AK1Value;
}
else
{
return K / AK1Value;
}
}
__host__ static auto CalculateBK0(index_t K)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
return CalculateKPadded(K) / BK1Value;
}
else
{
return K / BK1Value;
}
}
__host__ static auto CalculateMBlock(index_t M)
{
return math::integer_divide_floor(M, MPerBlock);
}
__host__ static auto CalculateNBlock(index_t N)
{
return math::integer_divide_floor(N, NPerBlock);
}
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(make_tuple(
Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
}
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
struct Problem
{
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
KPadded{CalculateKPadded(K_)},
AK0{CalculateAK0(K_)},
BK0{CalculateBK0(K_)},
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
{
}
__host__ void Print() const
{
std::cout << "problem {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << "}" << std::endl;
}
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t MPadded;
index_t NPadded;
index_t KPadded;
index_t AK0;
index_t BK0;
index_t MBlock;
index_t NBlock;
};
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument(const FloatA* p_a_grid_,
const FloatB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
{
}
const FloatA* p_a_grid;
const FloatB* p_b_grid;
FloatC* p_c_grid;
};
// FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
}
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
}
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
b_block_space_size_aligned * sizeof(ComputeTypeB)),
c_block_size * sizeof(FloatCShuffle));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Problem& problem)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(problem.M % MPerBlock == 0))
{
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(problem.N % NPerBlock == 0))
{
return false;
}
}
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding)
{
if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
!(CalculateKPadded(problem.K) % BK1Value == 0))
{
return false;
}
}
else
{
if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(problem.K % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(problem.M % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(problem.N % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(problem.K % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
// check gridwise gemm pipeline
const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
if(num_k_loop < 4)
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return num_loop > 3;
}
__host__ static constexpr index_t CalculateKBlockLoopTailNum(index_t K)
{
const index_t num_loop = K / KPerBlock;
if(num_loop % 2 == 1)
return 3;
else
return 2;
}
template <typename CGridDesc>
__device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop, index_t TailNum = 3>
__device__ static void Run(const FloatA* p_a_grid,
const FloatB* p_b_grid,
FloatC* p_c_grid,
void* p_shared_0,
void* p_shared_1,
const Problem& problem)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
#if 0
if(threadIdx.x == 0){
printf("Hardware assigned No. %03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n",
get_block_1d_id(),
block_work_idx[I0],
block_work_idx[I1],
__smid()>>6 & 0xf,
__smid()>>4 & 0x3,
__smid() & 0xf);
}
#endif
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatA,
ComputeTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
ComputeTypeB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
// BlockSize,
// ComputeType,
// FloatGemmAcc,
// decltype(a_block_desc_ak0_m_ak1),
// decltype(b_block_desc_bk0_n_bk1),
// MPerXdl,
// NPerXdl,
// MXdlPerWave,
// NXdlPerWave,
// KPack,
// LoopSched>();
auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4<
BlockSize,
ComputeTypeA,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>{}; // TransposeC
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeTypeA*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeTypeB*>(p_shared_0) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeTypeA*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeTypeB*>(p_shared_1) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
// gridwise GEMM pipeline
static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
// const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared_0),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
}
};
} // namespace ck
......@@ -268,6 +268,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
{
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
......@@ -329,6 +344,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
{
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
......
......@@ -328,7 +328,7 @@ struct WmmaSelector
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
static constexpr auto GetWmma<int4_t, int, 16, 16>()
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu4;
}
......
......@@ -189,6 +189,7 @@ struct vector_type<T, 1>
}
};
int static err = 0;
template <typename T>
struct vector_type<T, 2>
{
......@@ -221,6 +222,10 @@ struct vector_type<T, 2>
{
return data_.d2x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -236,6 +241,10 @@ struct vector_type<T, 2>
{
return data_.d2x1_;
}
else
{
return err;
}
}
};
......@@ -278,6 +287,10 @@ struct vector_type<T, 4>
{
return data_.d4x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -298,6 +311,10 @@ struct vector_type<T, 4>
{
return data_.d4x1_;
}
else
{
return err;
}
}
};
......@@ -347,6 +364,10 @@ struct vector_type<T, 8>
{
return data_.d8x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -372,6 +393,10 @@ struct vector_type<T, 8>
{
return data_.d8x1_;
}
else
{
return err;
}
}
};
......@@ -428,6 +453,10 @@ struct vector_type<T, 16>
{
return data_.d16x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -458,6 +487,10 @@ struct vector_type<T, 16>
{
return data_.d16x1_;
}
else
{
return err;
}
}
};
......@@ -520,6 +553,10 @@ struct vector_type<T, 32>
{
return data_.d32x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -554,6 +591,10 @@ struct vector_type<T, 32>
{
return data_.d32x1_;
}
else
{
return err;
}
}
};
......@@ -623,6 +664,10 @@ struct vector_type<T, 64>
{
return data_.d64x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -662,6 +707,10 @@ struct vector_type<T, 64>
{
return data_.d64x1_;
}
else
{
return err;
}
}
};
......@@ -737,6 +786,10 @@ struct vector_type<T, 128>
{
return data_.d128x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -780,6 +833,10 @@ struct vector_type<T, 128>
{
return data_.d128x1_;
}
else
{
return err;
}
}
};
......@@ -861,6 +918,10 @@ struct vector_type<T, 256>
{
return data_.d256x1_;
}
else
{
return err;
}
}
template <typename X>
......@@ -908,6 +969,10 @@ struct vector_type<T, 256>
{
return data_.d256x1_;
}
else
{
return err;
}
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -19,6 +19,12 @@ struct is_known_at_compile_time<index_t>
static constexpr bool value = false;
};
template <>
struct is_known_at_compile_time<unsigned int>
{
static constexpr bool value = false;
};
template <>
struct is_known_at_compile_time<long_index_t>
{
......
......@@ -178,4 +178,15 @@ __host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
return math::max(TupleDepth<depth + 1>(Ts{})...);
}
template <index_t from, index_t to, typename... Ts>
__host__ __device__ constexpr auto TupleSlice(const Tuple<Ts...>& tuple)
{
return generate_tuple(
[&](auto i) {
using Idx = Number<from + i>;
return tuple.At(Idx{});
},
Number<to - from>{});
}
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -14,24 +14,28 @@ namespace wrapper {
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). It is possible to pass nested shapes
* (e.g. ((4, 2), 2)), nested dimensions are merged.
* \tparam Strides Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). Stride tuple should be nested if shape tuple is
* nested.
* \tparam UnrolledDescriptorType Tensor descriptor for unnested shape dims.
*/
template <typename Shape, typename Strides>
template <typename Shape, typename UnrolledDescriptorType>
struct Layout
{
private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
// Generate default idxs tuple (idx with all merged nested shapes)
/**
* \brief Generate default indices tuple (idx with all merged nested shapes)
*
* \param shape Shape to align.
* \return Multi idx tuple with zeros.
*/
template <typename... Ts>
__host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple<Ts...>&)
__host__ __device__ constexpr static auto
GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple<Ts...>& shape)
{
return generate_tuple(
[&](auto) {
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
// runtime layout
return index_t(0);
......@@ -45,32 +49,18 @@ struct Layout
Number<Tuple<Ts...>::Size()>{});
}
// Generate packed (column-major) strides if not passed
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return generate_tuple(
[&](auto i) {
if constexpr(i.value == 0)
{
return I1;
}
else
{
return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
unrolled_shape);
}
},
Number<decltype(unrolled_shape)::Size()>{});
}
// Generate LowerDims in Compile-time for MergeTrasform using passed Type
// If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
// If tuple is element, then pass through (sequence with one element)
/**
* \brief Generate lower dims in compile-time for the Merge transform using
* provided type. If element of nested Tuple<Ts...> is also a tuple, then
* merge (generate sequence for merge). If tuple is element, then pass
* through (sequence with one element).
*
* \param shape Shape to align.
* \return LowerDims for MergeTrasform.
*/
template <typename Idx, typename... Ts>
__host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
__host__ __device__ constexpr static auto
GenerateLowerDim([[maybe_unused]] const Tuple<Ts...>& shape)
{
if constexpr(Idx::value == 0)
{
......@@ -110,11 +100,17 @@ struct Layout
}
}
// Iterate over nested tuples in shape
// Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
// Example idx: (1, 1), 1, 1
// Example shape: (2, (2, 2)), 2, (2, 2)
// Unrolled shape: 2, (2, 2), 2, (2, 2)
/**
* \brief Iterate over the nested tuples in the shape.
* Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
* Example idx: (1, 1), 1, 1
* Example shape: (2, (2, 2)), 2, (2, 2)
* Unrolled shape: 2, (2, 2), 2, (2, 2)
*
* \param shape Layout shape.
* \param idx Idx to align.
* \return Algined shape.
*/
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx)
......@@ -149,6 +145,13 @@ struct Layout
}
}
/**
* \brief Merge descriptor to 1D.
*
* \param shape Layout shape.
* \param desc Descriptor to merge.
* \return 1D descriptor.
*/
template <typename... ShapeDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
const DescriptorToMerge& desc)
......@@ -160,18 +163,41 @@ struct Layout
const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
const auto upper_dims = make_tuple(Sequence<0>{});
// Merge to 1d
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
return transform_tensor_descriptor(
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
}
else
{
// If the descriptor is known at the compilation time,
// use `make_merge_transform_v1_carry_check` because it doesn't use
// memcpy.
return transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform_v1_carry_check(merge_elems)),
lower_dims,
upper_dims);
}
}
// Merge nested shape dims when corresponding index is also nested.
// Input desc shape: 2, 2, 2, 2, 2, 2
// Example idx: 1, 1, 1, 1
// Example shape: 2, (2, 2), 2, (2, 2)
// Merged shape: 2, 4, 2, 4
/**
* \brief Merge nested shape dims when corresponding index is also merged.
* Input desc shape: 2, 2, 2, 2, 2, 2
* Example idx: 1, 1, 1, (1, 1)
* Example shape: 2, (2, 2), 2, (2, 2)
* Merged shape: 2, 4, 2, 2, 2
*
* \param shape Layout shape.
* \param idxs Indexes to align descriptor.
* \param desc Descriptor to merge.
* \return Aligned descriptor to idx.
*/
template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto CreateMergedDescriptor(
const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
__host__ __device__ constexpr static auto
CreateMergedDescriptor(const Tuple<ShapeDims...>& shape,
[[maybe_unused]] const Tuple<IdxDims...>& idxs,
DescriptorToMerge& desc)
{
const auto transforms = generate_tuple(
[&](auto i) {
......@@ -183,9 +209,19 @@ struct Layout
// If shape element is tuple and idx element is Number, then merge
// Unroll and reverse tuple to traverse column-major
const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
return make_merge_transform(merge_elems);
}
else
{
// If the descriptor is known at the compilation time,
// use `make_merge_transform_v1_carry_check` because
// it doesn't use memcpy.
return make_merge_transform_v1_carry_check(merge_elems);
}
}
else
{
// If shape element is integer and idx element is tuple, passed idx is wrong
static_assert(
......@@ -207,33 +243,24 @@ struct Layout
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
}
template <typename LayoutShape, typename LayoutStrides>
__host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
const auto unrolled_strides = UnrollNestedTuple(strides);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
using DeducedStrides =
std::conditional_t<is_same_v<Strides, Tuple<>>,
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
Strides>;
using FlattenDescriptorType =
remove_cvref_t<decltype(MakeFlattenDescriptor(Shape{}, DeducedStrides{}))>;
using Descriptor1dType =
remove_cvref_t<decltype(MakeMerge1d(Shape{}, FlattenDescriptorType{}))>;
remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
public:
/**
* \brief Transform descriptor to align to passed indexes.
*
* \param shape Layout shape.
* \param idxs Indexes to align descriptor.
* \param naive_descriptor Descriptor to merge.
* \return Aligned descriptor to idx.
*/
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto
TransformDesc(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx,
const FlattenDescriptorType& naive_descriptor)
const Tuple<IdxDims...>& idxs,
const UnrolledDescriptorType& naive_descriptor)
{
if constexpr(Tuple<IdxDims...>::Size() == I1)
{
......@@ -249,55 +276,38 @@ struct Layout
static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
"Idx rank and Shape rank must be the same (except 1d).");
// Unroll while IdxDims is nested
const auto aligned_shape = AlignShapeToIdx(shape, idx);
const auto aligned_shape = AlignShapeToIdx(shape, idxs);
// Transform correct form of shape
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor);
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor);
}
}
using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
Shape{}, DefaultIdxsTupleType{}, FlattenDescriptorType{}))>;
Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>;
public:
__host__ __device__ constexpr auto GetElementSpaceSize() const
{
return flatten_descriptor_.GetElementSpaceSize();
return unrolled_descriptor_.GetElementSpaceSize();
}
__host__ __device__ Layout() = delete;
/**
* \brief Layout constructor.
*
* \param shape Shape for layout.
* \param strides Strides for layout (optional if tensor is packed).
* \param unnested_descriptor Descriptor
*/
__host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides)
: flatten_descriptor_{}, shape_(shape), strides_(strides)
__host__ __device__ constexpr Layout(const Shape& shape,
const UnrolledDescriptorType& unnested_descriptor)
: unrolled_descriptor_(unnested_descriptor), shape_(shape)
{
// Construct if runtime mode
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
{
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
}
}
/**
* \brief Layout constructor (with default packed column-major strides).
*
* \param shape Shape for layout.
*/
__host__ __device__ constexpr Layout(const Shape& shape)
: flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_))
{
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_);
merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_);
}
}
......@@ -310,9 +320,9 @@ struct Layout
template <typename Idxs>
__host__ __device__ constexpr index_t operator()() const
{
static_assert(FlattenDescriptorType::IsKnownAtCompileTime(),
static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(),
"Compiletime operator used on runtime layout.");
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{}));
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
}
......@@ -339,7 +349,7 @@ struct Layout
else
{
// Custom index, need to transform descriptor
const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_);
const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_);
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
}
}
......@@ -351,7 +361,7 @@ struct Layout
* \return Calculated size.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t GetLength() const
__host__ __device__ constexpr auto GetLength() const
{
const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
......@@ -371,7 +381,7 @@ struct Layout
*
* \return Calculated size.
*/
__host__ __device__ constexpr index_t GetLengths() const
__host__ __device__ constexpr auto GetLengths() const
{
const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
......@@ -385,13 +395,6 @@ struct Layout
*/
__host__ __device__ constexpr const Shape& GetShape() const { return shape_; }
/**
* \brief Strides getter.
*
* \return Strides.
*/
__host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; }
/**
* \brief Get default lengths (tuple filled with Shape length elements).
*
......@@ -413,21 +416,56 @@ struct Layout
}
/**
* \brief Get default descriptor (with the same size as Shape)
* \brief Get descriptor with all nested dimensions merged.
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (4, 2)
*
* \note The size of merged descriptor is the same as Layout's shape.
*
* \return Default descriptor.
* \return Merged nests descriptor.
*/
__host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor()
__host__ __device__ constexpr const MergedNestsDescriptorType&
GetMergedNestingDescriptor() const
{
return merged_nests_descriptor_;
}
/**
* \brief Get descriptor with all dimensions are merged (1D).
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (8)
*
* \return 1D descriptor.
*/
__host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const
{
return descriptor_1d_;
}
/**
* \brief Get unnested descriptor (with unrolled dims)
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (2, 2, 2)
*
* \return Flattened descriptor.
*/
__host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const
{
return unrolled_descriptor_;
}
private:
FlattenDescriptorType flatten_descriptor_;
// All dimensions are unrolled
UnrolledDescriptorType unrolled_descriptor_;
// 1D descriptor
Descriptor1dType descriptor_1d_;
// All nesting are merged
MergedNestsDescriptorType merged_nests_descriptor_;
// Example, shape: ((2, 2), 2)
// UnrolledDescriptorType lengths: (2, 2, 2)
// Descriptor1dType lengths: (8)
// MergedNestsDescriptorType lengths: (4, 2)
const Shape shape_;
const DeducedStrides strides_;
};
} // namespace wrapper
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "../utils/tensor_utils.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace wrapper {
/**
* \brief Perform generic copy between two tensors partitions (threadwise copy).
* Tensors must have the same size.
*
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
*/
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
if constexpr(!SrcTensorType::IsDynamicBuffer)
{
using SizeType = decltype(size(src_tensor));
static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
}
else if constexpr(!DstTensorType::IsDynamicBuffer)
{
using SizeType = decltype(size(dst_tensor));
static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
}
else
{
for(int i = 0; i < size(src_tensor); i++)
{
dst_tensor(i) = src_tensor(i);
}
}
}
/**
* \brief Perform optimized copy between two tensors partitions (threadwise copy).
* Tensors must have the same size.
*
* \tparam DimAccessOrderTuple Tuple with dimension access order.
* \tparam VectorDim Dimension for vectorized read and write.
* \tparam ScalarPerVector Number of scalar per vectorized read and write.
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
*/
template <typename DimAccessOrderTuple,
index_t VectorDim,
index_t ScalarPerVector,
typename SrcTensorType,
typename DstTensorType>
__device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
const auto& in_grid_desc = layout(src_tensor).GetUnrolledDescriptor();
const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();
using SrcShapeType = remove_cvref_t<decltype(shape(src_tensor))>;
constexpr index_t num_dims = SrcShapeType::Size();
constexpr auto thread_slice_lengths =
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
constexpr auto dim_access_order = generate_sequence_v2(
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
if constexpr(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
{
// Perform a copy between DynamicBuffers
auto transfer = ThreadwiseTensorSliceTransfer_v7<
Tuple<typename SrcTensorType::TensorElementType>,
Tuple<typename DstTensorType::TensorElementType>,
decltype(tie(in_grid_desc)),
decltype(tie(out_grid_desc)),
tensor_operation::element_wise::PassThrough,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
Sequence<false>,
Sequence<false>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(tie(in_grid_desc),
tie(src_tensor.GetBuffer()),
tie(out_grid_desc),
tie(dst_tensor.GetBuffer()));
}
else if constexpr(!SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
{
// Perform copy from StaticBuffer to DynamicBuffer
const auto src_slice_origin_idxs =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
auto transfer =
ThreadwiseTensorSliceTransfer_v1r3<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
tensor_operation::element_wise::PassThrough,
decltype(thread_slice_lengths),
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
InMemoryDataOperationEnum::Set,
I1,
true>{out_grid_desc,
dst_tensor.GetMultiIdxOffsets(),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(in_grid_desc,
src_slice_origin_idxs,
src_tensor.GetBuffer(),
out_grid_desc,
dst_tensor.GetBuffer());
}
else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
{
// Perform copy from DynamicBuffer to StaticBuffer
const auto src_dst_slice_origin =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
[&](auto I) {
if constexpr(I == VectorDim)
{
return Number<ScalarPerVector>{};
}
else
{
return I1;
}
},
Number<num_dims>{});
auto transfer =
ThreadwiseTensorSliceTransfer_v4r1<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
decltype(src_vector_tensor_lengths),
decltype(dim_access_order)>{
src_tensor.GetMultiIdxOffsets()};
transfer.Run(in_grid_desc,
src_dst_slice_origin,
src_tensor.GetBuffer(),
out_grid_desc,
src_dst_slice_origin,
dst_tensor.GetBuffer());
}
else
{
// Perform copy between StaticBuffers
copy(src_tensor, dst_tensor);
}
}
} // namespace wrapper
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "utils/tensor_utils.hpp"
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"
namespace ck {
namespace wrapper {
namespace detail {
namespace {
/**
* \brief Tensor wrapper that performs static and dynamic buffer logic.
* \brief Check if Tuple contains Slice object
*
* \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
* \tparam ElementType Element data type.
* \tparam Shape Tensor shape (layout component).
* \tparam Strides Tensor strides (layout component).
* \tparam NumVectors Number of vectors (only for VGPR, SGPR).
* \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR).
* \return True if tuple contains Slice object.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors, // param for Register memory
index_t ScalarPerVector // param for Register memory
>
struct Tensor
template <typename T>
__host__ __device__ constexpr bool HasSlice(T&&)
{
private:
// Check if Tuple contains Slice object
template <typename T>
constexpr static bool IsSlicing(T&&)
{
return is_detected<is_slice, T>::value;
}
template <typename... Ts>
constexpr static bool IsSlicing(Tuple<Ts...>&&)
{
return (IsSlicing(Ts{}) || ...);
}
// Calculate first index of new tensor after slice
// It is needed to calculate offset for new tensor
template <typename... Ts>
constexpr auto GetStartIdxForSlicedTensor(const Tuple<Ts...>& idx) const
{
const auto start_idx_for_sliced_tensor = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if tuple then recurrence
return GetStartIdxForSlicedTensor(idx.At(num_i));
}
else if constexpr(is_detected<is_slice,
tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if slice, return the beginning of the interval
return idx.At(num_i).from_;
}
else
{
// if one dim selected
return idx.At(num_i);
}
},
Number<Tuple<Ts...>::Size()>{});
return start_idx_for_sliced_tensor;
}
}
template <typename... Ts>
__host__ __device__ constexpr bool HasSlice(Tuple<Ts...>&&)
{
return (HasSlice(Ts{}) || ...);
}
// Calculate new tensor shape after slice
template <typename... Ts, typename ShapeTmpType>
constexpr auto GetShapeFromSlicedTensor(const Tuple<Ts...>& idx,
const ShapeTmpType& shape) const
{
/**
* \brief Calculate new shape after slice from parent shape.
*
* \param idxs Tuple of indexes defining slice ranges.
* \param shape Shape which will be sliced.
* \return New tensor shape.
*/
template <typename... Ts, typename SlicedShape>
__host__ __device__ constexpr auto GetSlicedShape(const Tuple<Ts...>& idxs,
const SlicedShape& shape)
{
// Pack each value in tuple to remove empty tuples after generation
auto new_shape = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
if constexpr(!IsSlicing(tuple_element_t<i.value, Tuple<Ts...>>{}))
if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
{
// if tuple does not have any slice then we can remove dimension
return Tuple<>{};
......@@ -90,15 +53,14 @@ struct Tensor
else
{
// if tuple then recurrence
return make_tuple(GetShapeFromSlicedTensor(idx.At(num_i), shape.At(num_i)));
return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i)));
}
}
else if constexpr(is_detected<is_slice,
tuple_element_t<i.value, Tuple<Ts...>>>::value)
else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// calculate new dimension
const auto& dim = size(shape.At(num_i));
const auto val = idx.At(num_i).range(dim);
const auto val = idxs.At(num_i).range(dim);
return make_tuple(val);
}
else
......@@ -110,50 +72,143 @@ struct Tensor
Number<Tuple<Ts...>::Size()>{});
// Remove empty tuples (deleted elements) and return
return UnrollNestedTuple<0, 1>(new_shape);
}
}
template <typename... Ts, typename StridesTmpType>
constexpr auto GetStridesFromSlicedTensor(const Tuple<Ts...>& idx,
const StridesTmpType& strides) const
{
/**
* \brief Generate Freeze for each of nested shape.
*
* \param idx Tuple of start indices for slice.
* \param shape Shape which will be freezed.
* \return Generated freeze transforms.
*/
template <typename T, typename Shape>
__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return generate_tuple(
[&](auto i) {
// dimension offset from idx
const auto dim = unrolled_shape.At(Number<i>{});
const auto dim_idx = idx % dim;
idx /= dim;
return make_freeze_transform(dim_idx);
},
Number<decltype(unrolled_shape)::Size()>{});
}
/**
* \brief Generate transforms for slice tensor.
*
* \param idx Tuple of start indices for slice.
* \param shape Shape which will be sliced.
* \return Generated transforms.
*/
template <typename... Ts, typename Shape>
__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple<Ts...>& idx,
const Shape& shape)
{
// Pack each value in tuple to remove empty tuples after generation
auto new_strides = generate_tuple(
auto transforms = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
if constexpr(!IsSlicing(tuple_element_t<i.value, Tuple<Ts...>>{}))
{
// if tuple does not have any slice then we can remove dimension
return Tuple<>{};
return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i));
}
else
else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if tuple then recurrence
return make_tuple(
GetStridesFromSlicedTensor(idx.At(num_i), strides.At(num_i)));
}
}
else if constexpr(is_detected<is_slice,
tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// Stride will be the same
return make_tuple(strides.At(num_i));
const auto from = idx.At(num_i).from_;
const auto dim = size<num_i>(shape);
const auto range = idx.At(num_i).range(dim);
return make_slice_transform(range, from, from + range);
}
else
{
// remove dimension for just value
return Tuple<>{};
return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
}
},
Number<Tuple<Ts...>::Size()>{});
// Remove empty tuples (deleted elements) and return
return UnrollNestedTuple<0, 1>(new_strides);
return UnrollNestedTuple(transforms);
}
template <index_t i, typename LowerIndex>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&)
{
// There is no output for Freeze transform
return Sequence<>{};
}
template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&)
{
return Sequence<i>{};
}
template <index_t i>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&)
{
return Tuple<>{};
}
template <index_t i, typename... Transforms>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>& transforms)
{
constexpr auto num_transforms = Tuple<Transforms...>::Size();
// Deduce Sequence element for specific transform
const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
{
const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
return concat_tuple(make_tuple(current_elem), next_tuple);
}
else
{
// Increase i if current_elem is Slice transform
const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
return concat_tuple(make_tuple(current_elem), next_tuple);
}
}
template <typename... Ts, typename Shape, typename FlattenDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
const Shape& shape,
const FlattenDescriptor& flatten_desc)
{
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
const auto transforms = GenerateSliceTransforms(idx, shape);
using TransformsTupleType = decltype(transforms);
const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
} // namespace
} // namespace detail
/**
* \brief Tensor wrapper that performs static and dynamic buffer logic.
* The tensor is based on a descriptor stored in the Layout. Additionally,
* tensor can be sliced or shifted using multi-index offset.
*
* \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
* \tparam ElementType Element data type.
* \tparam Shape Tensor shape (layout component).
* \tparam UnrolledDescriptorType Flatten descriptor (layout component).
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnrolledDescriptorType>
struct Tensor
{
public:
using ElementSpaceSize = decltype(Layout<Shape, Strides>{
Shape{}, Strides{}}.GetElementSpaceSize()); // SpaceSize type for buffer
using ElementSpaceSize = decltype(Layout<Shape, UnrolledDescriptorType>{
Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
using TensorElementType = ElementType; // DataType
static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
......@@ -161,135 +216,178 @@ struct Tensor
BufferAddressSpace == MemoryTypeEnum ::Vgpr);
__host__ __device__ Tensor() = delete;
__host__ __device__ Tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
__host__ __device__ constexpr Tensor(ElementType* pointer,
const Layout<Shape, UnrolledDescriptorType>& layout)
: layout_(layout),
buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize()))
buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())),
multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
base_offset_(0)
{
static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
}
__host__ __device__ Tensor(const Layout<Shape, Strides>& layout) : layout_(layout)
__host__ __device__ constexpr Tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
: layout_(layout),
multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
base_offset_(0)
{
static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
}
__host__ __device__ constexpr const Layout<Shape, Strides>& GetLayout() const
__host__ __device__ constexpr const Layout<Shape, UnrolledDescriptorType>& GetLayout() const
{
return layout_;
}
// Getter for new sliced tensor
template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
__host__ __device__ auto operator[](const Tuple<Ts...>& idx) const
/**
* \brief Get the new sliced tensor.
*
* \param idx Tuple of indices: slice(from,to) or scalar.
* \return Sliced tensor.
*/
template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ auto operator[](const Tuple<Ts...>& idx)
{
static_assert(IsDynamicBuffer, "Register slice is not supported");
// Calculate offset based on first idx for new tensor
const index_t offset = layout_(GetStartIdxForSlicedTensor(idx));
const auto& shape = layout_.GetShape();
auto new_shape = detail::GetSlicedShape(idx, shape);
auto new_shape = GetShapeFromSlicedTensor(idx, layout_.GetShape());
if constexpr(is_same_v<Strides, Tuple<>>)
{
auto new_layout = make_layout(new_shape);
return make_tensor<BufferAddressSpace>(buffer_.p_data_ + offset, new_layout);
}
else
{
auto new_strides = GetStridesFromSlicedTensor(idx, layout_.GetStrides());
auto new_layout = make_layout(new_shape, new_strides);
return make_tensor<BufferAddressSpace>(buffer_.p_data_ + offset, new_layout);
}
const auto& flatten_desc = layout_.GetUnrolledDescriptor();
auto new_desc = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc);
const auto new_layout =
Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
// Update embed offset
base_offset_ -= new_layout(make_tuple(Number<0>{}));
return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
}
template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
__host__ __device__ auto operator()(const Tuple<Ts...>& idx) const
template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ auto operator()(const Tuple<Ts...>& idx)
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<IsSlicing(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ auto operator()(Idxs... idxs) const
template <typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ auto operator()(Idxs... idxs)
{
return this->operator[](make_tuple(idxs...));
}
// Getter for the const value
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
/**
* \brief Getter of the tensor's const value reference.
*
* \param idx Tuple of indices.
* \return Requested value.
*/
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ const ElementType& operator[](const Tuple<Ts...>& idx) const
{
if constexpr(IsDynamicBuffer)
{
const index_t offset = layout_(idx);
const index_t offset = layout_(idx) + base_offset_;
return buffer_[offset];
}
else
{
if constexpr(is_same_v<Strides, Tuple<>>)
{
constexpr index_t offset =
Layout<Shape, Strides>{Shape{}}.template operator()<Tuple<Ts...>>();
return buffer_[Number<offset>{}];
}
else
{
constexpr index_t offset =
Layout<Shape, Strides>{Shape{}, Strides{}}.template operator()<Tuple<Ts...>>();
return buffer_[Number<offset>{}];
}
constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
Shape{},
UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
// Calculate and apply base offset in compile-time
constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
Shape{},
UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
return buffer_[Number<index_offset + base_offset>{}];
}
}
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ const ElementType& operator()(const Tuple<Ts...>& idx) const
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<!IsSlicing(Tuple<Idxs...>{}), bool> = false>
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ const ElementType& operator()(Idxs... idxs) const
{
return this->operator[](make_tuple(idxs...));
}
// Getter for the value reference
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
/**
* \brief Getter of tensor value reference.
*
* \param idx Tuple of indices.
* \return Requested value.
*/
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ ElementType& operator[](const Tuple<Ts...>& idx)
{
if constexpr(IsDynamicBuffer)
{
const index_t offset = layout_(idx);
const index_t offset = layout_(idx) + base_offset_;
return buffer_(offset);
}
else
{
if constexpr(is_same_v<Strides, Tuple<>>)
{
constexpr index_t offset =
Layout<Shape, Strides>{Shape{}}.template operator()<Tuple<Ts...>>();
return buffer_(Number<offset>{});
}
else
{
constexpr index_t offset =
Layout<Shape, Strides>{Shape{}, Strides{}}.template operator()<Tuple<Ts...>>();
return buffer_(Number<offset>{});
}
constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
Shape{},
UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
// Apply embed offset (calculate in compiletime)
constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
Shape{},
UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
return buffer_(Number<index_offset + base_offset>{});
}
}
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<!IsSlicing(Tuple<Idxs...>{}), bool> = false>
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ ElementType& operator()(Idxs... idxs)
{
return this->operator[](make_tuple(idxs...));
}
__host__ __device__ constexpr auto GetDefaultDescriptor()
/**
* \brief Get descriptor with all nested dimensions merged.
*
* \return Merged nests descriptor.
*/
__host__ __device__ constexpr auto GetMergedNestingDescriptor()
{
return layout_.GetMergedNestingDescriptor();
}
/**
* \brief Get pointer to the data.
*
* \return Pointer.
*/
__host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }
__host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
__host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }
/**
* \brief Get multi index offset to the data.
*
* \return Multi index offset.
*/
__host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }
/**
* \brief Apply multi index offset on the tensor.
*
* \param multi_idx_offset Multi index offset.
*/
template <typename MultiIdxOffsets>
__host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
{
return layout_.GetDefaultDescriptor();
multi_idx_offset_ = multi_idx_offset;
base_offset_ += layout_(multi_idx_offset);
}
private:
......@@ -297,17 +395,28 @@ struct Tensor
ElementType,
ElementSpaceSize,
true /*InvalidElementUseNumericalZeroValue*/>;
using StaticBufferType =
StaticBufferTupleOfVector<BufferAddressSpace,
using StaticBufferType = StaticBuffer<BufferAddressSpace,
ElementType,
NumVectors,
ScalarPerVector,
size(Shape{}),
true /*InvalidElementUseNumericalZeroValue*/>;
// If register use static buffer, else use dynamic buffer
using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
const Layout<Shape, Strides> layout_;
const Layout<Shape, UnrolledDescriptorType> layout_;
Buffer buffer_;
// We use multi_idx_offset_ to enable the creation of a descriptor in
// compile time for partitions or tiles if tile shape and thread layout
// is known at compile time (We can use the same descriptor for each
// thread). Additionally, the copy between the static and dynamic buffer
// requires a descriptor known at compile time, so we can shift data using
// such multi_idx_offset_.
MultiIndex<Shape::Size()> multi_idx_offset_;
// Base offset and multi index offset are corresponding to exactly the
// same element in tensor ( and in physical memory ). Multi index offset
// is multi dimensional index. However base offset is calculated using
// tensor descriptor (thus all it's transforms) and is linear (1D).
// We store base_offset_ to avoid multiple recalculations.
index_t base_offset_;
};
} // namespace wrapper
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -22,11 +22,69 @@ namespace wrapper {
// Disable from doxygen docs generation
/// @cond
// forward declaration
template <typename Shape, typename Strides>
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
namespace {
/**
* \brief Generate packed (column-major) strides if not passed
*
* \param shape Tensor shape.
* \return Generated column-major strides.
*/
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return generate_tuple(
[&](auto i) {
if constexpr(i.value == 0)
{
return Number<1>{};
}
else
{
return TupleReduce<Number<0>{}.value, i.value>([](auto x, auto y) { return x * y; },
unrolled_shape);
}
},
Number<decltype(unrolled_shape)::Size()>{});
}
/**
* \brief Create naive tensor descriptor from nested shape.
*
* \param shape Tensor shape.
* \param strides Tensor strides.
* \return Unrolled descriptor
*/
template <typename LayoutShape, typename LayoutStrides>
__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
if constexpr(is_same_v<LayoutStrides, Tuple<>>)
{
// if not passed, then generate
const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
else
{
const auto unrolled_strides = UnrollNestedTuple(strides);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
}
} // namespace
/// @endcond
// make_*
......@@ -38,10 +96,10 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
* \return Constructed layout.
*/
template <typename Shape, typename Strides>
__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
const Strides& strides)
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{
return Layout<Shape, Strides>(shape, strides);
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
}
/**
......@@ -52,16 +110,21 @@ __host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& sh
* \return Constructed layout.
*/
template <typename Shape>
__host__ __device__ constexpr Layout<Shape, Tuple<>> make_layout(const Shape& shape)
__host__ __device__ constexpr auto make_layout(const Shape& shape)
{
return Layout<Shape, Tuple<>>(shape);
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
}
// Layout helpers
// get
// Get dim (could be returned from get with empty Idxs)
/**
* \private
* \brief Get dim.
*
* \param dim Dimension.
* \return Returned the same dimension.
*/
template <typename T>
__host__ __device__ T constexpr get(const T& dim)
......@@ -89,26 +152,51 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
* \param layout Layout to create sub layout.
* \return Requsted sub layout.
*/
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
template <index_t idx, typename Shape, typename FlattenDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
{
const auto& shape = layout.GetShape();
const auto& new_shape = get<idx>(shape);
const auto new_shape = get<idx>(shape);
static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
"Shape of sub layout must be tuple");
if constexpr(is_same_v<Strides, Tuple<>>)
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
constexpr auto new_shape_dims = decltype(UnrollNestedTuple(new_shape))::Size();
constexpr auto shape_offset = decltype(UnrollNestedTuple(TupleSlice<0, idx>(shape)))::Size();
const auto unrolled_shape = UnrollNestedTuple(shape);
const auto transforms = generate_tuple(
[&](auto i) {
// Compare Idx with shape
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
{
// Remove dimension
return make_freeze_transform(Number<0>{});
}
else
{
// If stride not passed, create without strides
return make_layout(new_shape);
return make_pass_through_transform(unrolled_shape.At(i));
}
},
Number<old_shape_dims>{});
const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
return Sequence<>{};
else
{
const auto& strides = layout.GetStrides();
const auto& new_strides = get<idx>(strides);
static_assert(is_detected<is_tuple, decltype(new_strides)>::value,
"Strides of sub layout must be tuple");
return make_layout(new_shape, new_strides);
return Sequence<i.value - shape_offset>{};
}
},
Number<old_shape_dims>{});
const auto& flatten_desc = layout.GetUnrolledDescriptor();
auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
}
/**
......@@ -125,9 +213,12 @@ __host__ __device__ constexpr auto get(const T& elem)
}
// size
// Get dim size (could be returned from get function)
/**
* \private
* \brief Get size.
*
* \param dim Size.
* \return Returned the same size.
*/
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
......@@ -142,8 +233,8 @@ __host__ __device__ T constexpr size(const T& dim)
* \param layout Layout to get Shape of.
* \return Requsted length.
*/
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
template <index_t idx, typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
{
return layout.template GetLength<idx>();
}
......@@ -155,7 +246,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
* \return Requsted size.
*/
template <typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
__host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
......@@ -168,8 +259,8 @@ __host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
* \param layout Layout to calculate shape size.
* \return Requsted size.
*/
template <typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
{
return layout.GetLengths();
}
......@@ -182,7 +273,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
* \return Requsted length.
*/
template <index_t idx, typename... Ts>
__host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple)
__host__ __device__ constexpr auto size(const Tuple<Ts...>& tuple)
{
return size(tuple.At(Number<idx>{}));
}
......@@ -208,8 +299,9 @@ __host__ __device__ constexpr auto size(const T& elem)
* \param layout Layout to calculate rank.
* \return Requsted rank.
*/
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout<Shape, Strides>& layout)
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
rank([[maybe_unused]] const Layout<Shape, UnrolledDescriptorType>& layout)
{
return Shape::Size();
}
......@@ -229,17 +321,25 @@ __host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& t
/**
* \private
* \brief Rank for scalar
*
* \param dim Dimension scalar.
* \return Returned 1.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t rank(const Number<IDim>&)
__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number<IDim>& dim)
{
return 1;
}
/**
* \private
* \brief Rank for scalar
*
* \param dim Dimension scalar.
* \return Returned 1.
*/
__host__ __device__ constexpr index_t rank(const index_t&) { return 1; }
__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim) { return 1; }
/**
* \brief Hierarchical rank.
......@@ -261,8 +361,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
* \param layout Layout to calculate depth.
* \return Requsted depth.
*/
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout)
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto depth(const Layout<Shape, UnrolledDescriptorType>& layout)
{
const auto& shape = layout.GetShape();
return TupleDepth(shape);
......@@ -282,17 +382,25 @@ __host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
/**
* \private
* \brief Depth for scalar
*
* \param dim Scalar.
* \return Returned 0.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t depth(const Number<IDim>&)
__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number<IDim>& dim)
{
return 0;
}
/**
* \private
* \brief Depth for scalar
*
* \param dim Scalar.
* \return Returned 0.
*/
__host__ __device__ constexpr index_t depth(const index_t&) { return 0; }
__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim) { return 0; }
/**
* \brief Hierarchical depth.
......@@ -307,26 +415,14 @@ __host__ __device__ constexpr auto depth(const T& elem)
return depth(get<Idxs...>(elem));
}
/**
* \brief Get Layout strides.
*
* \param layout Layout to get strides from.
* \return Requsted strides.
*/
template <typename Shape, typename Strides>
__host__ __device__ constexpr const auto& stride(const Layout<Shape, Strides>& layout)
{
return layout.GetStrides();
}
/**
* \brief Get Layout shape.
*
* \param layout Layout to get shape from.
* \return Requsted shape.
*/
template <typename Shape, typename Strides>
__host__ __device__ constexpr const auto& shape(const Layout<Shape, Strides>& layout)
template <typename LayoutType>
__host__ __device__ constexpr const auto& shape(const LayoutType& layout)
{
return layout.GetShape();
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "tensor_utils.hpp"
#include "layout_utils.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
namespace ck {
namespace wrapper {
namespace {
/**
* \brief Calculate shape for partition based on number of threads per each dim and
* previous shape
*
* \param shape Base tensor shape.
* \param thread_lengths Tuple of thread lengths.
* \return Partition shape.
*/
template <typename... Ts, typename... Ls>
__host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts...>& shape,
const Tuple<Ls...>& thread_lengths)
{
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
const auto slice_len = size<num_i>(shape) / thread_lengths.At(num_i);
return slice_len;
},
Number<Tuple<Ls...>::Size()>{});
}
/**
* \brief Calculate total number of blocks.
*
* \param shape Base tensor shape.
* \param tile_shape Tile shape.
* \return Tuple with blocks number.
*/
template <typename... Ts, typename... Ls>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
const Tuple<Ls...>& tile_shape)
{
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
return generate_tuple([&](auto i) { return size<i>(shape) / size<i>(tile_shape); },
Number<Tuple<Ls...>::Size()>{});
}
/**
* \brief Calculate scaled offset for new partition/tile.
*
* \param thread_idxs Thread 1d id.
* \param partition_lengths_seq Sequence of partition shape.
* \param old_offset_idxs Multi index offset from base tensor to shift values.
* \return Partition shape.
*/
template <typename ThreadIdxs, typename PartitionLengthsSeq, typename OldOffsetIdxs>
__host__ __device__ constexpr auto
CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
const PartitionLengthsSeq& partition_lengths_seq,
const OldOffsetIdxs& old_offset_idxs)
{
return thread_idxs * partition_lengths_seq + old_offset_idxs;
}
} // namespace
/**
* \brief Create local partition for thread (At now only packed partition
* is supported).
*
* \param tensor Tensor for partition.
* \param thread_lengths Layout of threads (could not be nested).
* \param thread_id Thread index represented as integer.
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
const index_t thread_id)
{
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
// Calculate new partition shape
const auto& tensor_shape = shape(tensor);
constexpr auto partition_shape =
CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{});
// Create Thread Cluster Descriptor
constexpr auto partition_lengths_seq = generate_sequence_v2(
[&](auto I) { return size<I>(partition_shape); }, Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_lengths_seq =
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
// Calculate thread idxs and offsets
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
const auto offset_multi_idxs =
CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
auto& flatten_desc = layout(tensor).GetUnrolledDescriptor();
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(flatten_desc)>(
partition_shape, flatten_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
// Apply offsets
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return partition_tensor;
}
/**
* \brief Create local tile for thread block. (At now only packed tile
* is supported).
*
* \note Temporary to gain the best performance use 2d
* tile_shape.
*
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \return Tile tensor.
*/
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
{
static_assert(!IsNestedTuple(BlockShapeTuple{}));
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
if constexpr(BlockShapeTuple::Size() == I2)
{
// Optimized version for 2d tile shape [MxK]
const auto block_2_tile_map =
BlockToCTileMap_M00_N0_M01Adapt<BlockShapeTuple{}.At(I0),
BlockShapeTuple{}.At(I1),
remove_cvref_t<decltype(aligned_desc)>>(aligned_desc);
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape));
const index_t k_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape));
const auto offset_multi_idxs =
make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid);
// Create new layout and tensor
const auto tile_layout =
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
aligned_desc);
auto tile_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
// Apply offsets
tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return tile_tensor;
}
else
{
// Calculate offsets
// Sequence with data to process per block
constexpr auto tile_shape_seq =
generate_sequence_v2([](auto I) { return size(BlockShapeTuple{}.At(I)); },
Number<BlockShapeTuple::Size()>{});
// Tuple with number of blocks
const auto block_lengths = CalculateGridSize(shape(tensor), tile_shape);
constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
const auto block_idxs =
block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
const auto offset_multi_idxs =
CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
const auto tile_layout =
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
aligned_desc);
auto tile_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
// Apply offsets
tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return tile_tensor;
}
}
} // namespace wrapper
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -10,6 +10,7 @@
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/dynamic_buffer.hpp"
#include "ck/utility/amd_address_space.hpp"
#include "ck/utility/multi_index.hpp"
namespace ck {
namespace wrapper {
......@@ -27,16 +28,12 @@ using MemoryTypeEnum = AddressSpaceEnum;
// Disable from doxygen docs generation
/// @cond
// forward declarations
template <typename Shape, typename Strides>
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors, // params for Register memory
index_t ScalarPerVector // param for Register memory
>
typename UnrolledDescriptorType>
struct Tensor;
template <typename FromType, typename ToType>
......@@ -45,13 +42,22 @@ struct Slice
__host__ __device__ constexpr Slice() : from_(), to_() {}
__host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {}
/**
* \brief Calculate slice range.
*
* \param dim Dimension size.
* \return Slice range.
*/
template <typename T>
__host__ __device__ constexpr auto range(const T& dim) const
{
if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
is_same_v<T, index_t>)
{
assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range");
if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_)))
{
throw std::runtime_error("Invalid range");
}
if(to_ < 0)
{
return dim - from_ + to_ + 1;
......@@ -98,33 +104,30 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
* \param layout Tensor layout.
* \return Constructed tensor.
*/
template <MemoryTypeEnum MemoryType, typename ElementType, typename Shape, typename Strides>
constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
template <MemoryTypeEnum MemoryType,
typename ElementType,
typename Shape,
typename UnrolledDescriptorType>
constexpr auto make_tensor(ElementType* pointer,
const Layout<Shape, UnrolledDescriptorType>& layout)
{
return Tensor<MemoryType, ElementType, Shape, Strides, 0 /*NumVectors*/, 0 /*ScalarPerVector*/>(
pointer, layout);
return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(pointer, layout);
}
/**
* \brief Make SGPR or VGPR tensor function.
*
* \tparam MemoryType Type of memory.
* \tparam NumVectors Number of vectors.
* \tparam ScalarPerVector Scalars per vector.
* \tparam ElementType Memory data type.
* \param layout Tensor layout.
* \return Constructed tensor.
*/
template <MemoryTypeEnum MemoryType,
index_t NumVectors,
index_t ScalarPerVector,
typename ElementType,
typename Shape,
typename Strides>
constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
typename UnrolledDescriptorType>
constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
{
static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported");
return Tensor<MemoryType, ElementType, Shape, Strides, NumVectors, ScalarPerVector>(layout);
return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
}
/**
......@@ -136,12 +139,9 @@ constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
typename UnrolledDescriptorType>
__host__ __device__ constexpr const auto&
layout(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
layout(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return tensor.GetLayout();
}
......@@ -157,12 +157,9 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr index_t
size(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
size(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return size<Idxs...>(tensor.GetLayout());
}
......@@ -178,12 +175,9 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr index_t
rank(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
rank(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return rank<Idxs...>(tensor.GetLayout());
}
......@@ -199,35 +193,13 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr index_t
depth(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
depth(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return depth<Idxs...>(tensor.GetLayout());
}
/**
* \brief Get Tensor strides.
*
* \param tensor Tensor to get strides from.
* \return Requsted strides.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return stride(tensor.GetLayout());
}
/**
* \brief Get Tensor shape.
*
......@@ -237,12 +209,9 @@ stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors,
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
typename UnrolledDescriptorType>
__host__ __device__ constexpr const auto&
shape(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
shape(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return shape(tensor.GetLayout());
}
......
......@@ -265,6 +265,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
return 0;
}
throw std::runtime_error("Col2Img: number of dimensions should be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,
......
......@@ -313,6 +313,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
return 0;
}
throw std::runtime_error(
"Conv_bwd_data: number of dimensions must be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,
......
......@@ -265,6 +265,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
return 0;
}
throw std::runtime_error("Conv_bwd: number of dimensions must be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment