"vscode:/vscode.git/clone" did not exist on "d6a666fa389b14c7ef1abc0f473d11ffd0c4062c"
Commit 522b7aee authored by Adam Osewski

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

parents ff936fd6 84832fc4
...@@ -35,15 +35,17 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
if(lengths.size() != NumDim1 + NumDim2)
{
std::ostringstream err;
-err << "Incorrect number of lengths in " << __FILE__ << ":" << __LINE__
-    << ", in function: " << __func__;
err << "Incorrect number of lengths in "
    << "device_contraction_utils.hpp"
    << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(strides.size() != NumDim1 + NumDim2)
{
std::ostringstream err;
-err << "Incorrect number of strides in " << __FILE__ << ":" << __LINE__
-    << ", in function: " << __func__;
err << "Incorrect number of strides in "
    << "device_contraction_utils.hpp"
    << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
// non-c-shuffle version currently has compiler issues with register spill, which further causes
// validation failures.
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceGemm_Xdl_CShuffleV2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v2<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer,
ComputeTypeA,
ComputeTypeB>;
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
float ave_time = 0;
const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
if(GridwiseGemm::CalculateKBlockLoopTailNum(K) == 3)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true, 2>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceGemm_Xdl_CShuffleV2"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
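For reference, a hypothetical instantiation of the new DeviceGemm_Xdl_CShuffleV2 device op could look like the sketch below. The layouts, data types and tuning values are illustrative placeholders (they are not taken from this commit); real configurations come from the composable_kernel instance factories.
// Illustrative sketch only: an fp16 GEMM instance of the new template.
using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using F16         = ck::half_t;
using F32         = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV2<
    Row, Col, Row,                          // ALayout, BLayout, CLayout
    F16, F16, F16,                          // ADataType, BDataType, CDataType
    F32, F16,                               // GemmAccDataType, CShuffleDataType
    PassThrough, PassThrough, PassThrough,  // elementwise ops for A, B and C
    ck::tensor_operation::device::GemmSpecialization::Default,
    1,                                      // NumGemmKPrefetchStage
    256,                                    // BlockSize
    256, 128, 32,                           // MPerBlock, NPerBlock, KPerBlock
    8, 8,                                   // AK1, BK1
    32, 32,                                 // MPerXDL, NPerXDL
    4, 2,                                   // MXdlPerWave, NXdlPerWave
    ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, true,
    ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, true,
    1, 1,                                   // CShuffle M/N XdlPerWave per shuffle
    ck::Sequence<1, 32, 1, 8>,              // CShuffle block-transfer cluster lengths
    8>;                                     // CShuffleBlockTransferScalarPerVector_NPerBlock
// auto gemm     = DeviceGemmInstance{};
// auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
//                                   PassThrough{}, PassThrough{}, PassThrough{});
// if(gemm.IsSupportedArgument(argument)) { gemm.MakeInvoker().Run(argument); }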
...@@ -174,6 +174,11 @@ struct PassThrough
{
y = x;
}
template <>
__host__ __device__ void operator()<int4_t, int>(int4_t& y, const int& x) const
{
y = type_convert<int4_t>(x);
}
#endif
template <>
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...@@ -134,6 +134,11 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{
#if 0
if(get_thread_global_1d_id()==0){
printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
}
#endif
}
template <typename CGridDesc_M_N>
...@@ -252,6 +257,302 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
BlockToCTileMap_M00_N0_M01Adapt;
};
// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt;
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(
const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
__host__ __device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
index_t N,
index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{
#if 0
if(get_thread_global_1d_id()==0){
printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
}
#endif
}
template <typename CGridDesc_M_N>
__host__ __device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
: BlockToCTileMap_Grouped_M00_N0_M01Adapt(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{
}
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return M0 * N0;
}
template <typename CGridDesc_M_N>
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum);
auto group_id = block_1d_id % GroupNum;
auto remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum;
index_t idx_N0 = remap_block_1d_id % N0;
index_t idx_M0 = remap_block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
private:
index_t M_;
index_t N_;
index_t M01_;
};
// keep the redundant type argument for backward compatibility
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
: BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
{
using BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>::
BlockToCTileMap_Grouped_M00_N0_M01Adapt;
};
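The worked example in the comment above can be checked with a short host-side transcription of CalculateBottomIndex. The function below is illustrative only (it is not part of this commit) and mirrors the device code line by line, returning the (M, N) C-tile indices for a flat block id.
#include <utility>
std::pair<int, int> grouped_m00_n0_m01_map(int block_1d_id, int M0, int N0, int GroupNum, int M01)
{
    block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
    const int group_size        = (M0 * N0 + GroupNum - 1) / GroupNum;
    const int group_id          = block_1d_id % GroupNum;
    const int remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum;
    const int idx_N0            = remap_block_1d_id % N0;
    const int idx_M0            = remap_block_1d_id / N0;
    const int M01_adapt         = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;
    const int idx_M00           = idx_M0 / M01;
    const int idx_M01           = idx_M0 % M01;
    const int idx_N0_M01_local  = idx_N0 + idx_M01 * N0;
    return {idx_N0_M01_local % M01_adapt + idx_M00 * M01,  // M tile index
            idx_N0_M01_local / M01_adapt};                 // N tile index
}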
// columns of row-vectors
// This C-tile map dynamically adjusts N01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_N00_M0_N01Adapt;
template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default;
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) =
default;
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) =
default;
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default;
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8)
: M_(M), N_(N), N01_(N01)
{
#if 0
if(get_thread_global_1d_id()==0){
printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_);
}
#endif
}
template <typename CGridDesc_M_N>
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t N01 = 8)
: BlockToCTileMap_N00_M0_N01Adapt(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01)
{
}
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return M0 * N0;
}
template <typename CGridDesc_M_N>
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
index_t idx_M0 = block_1d_id % M0;
index_t idx_N0 = block_1d_id / M0;
const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_;
index_t idx_N00 = idx_N0 / N01_;
index_t idx_N01 = idx_N0 % N01_;
index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0;
/**
* idxN0
*
* |< mtx N >|
*
* |<---N01--->|
* - |-----------|-----------|-----------|-----|-----|-
* ^ | 0 ----------> 1 | | | |
* | | / | | | | M_0 MPerBlock
* | / | | | |
* |------/----------------|-----------|-----|-----|-
* | | | | | | |
* idxM0 | V | | | | | M_1 MPerBlock
* | 2 ----------> 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | blockid | | | |
* | | 5 | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* Example:
* assume:
* N0 = 5
* M0 = 4
* block_1d_id = 5
* N01 = 2
*
* idx_M0 = 1
* idx_N0 = 1
* N01_adapt = 2
* idx_N00 = 0
* idx_N01 = 1
* idx_M0_N01_local = 5
* output {2, 1}
*/
return make_tuple(idx_M0_N01_local / N01_adapt,
idx_M0_N01_local % N01_adapt + idx_N00 * N01_);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
private:
index_t M_;
index_t N_;
index_t N01_;
};
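For comparison, the column-of-row-vectors variant walks the tile grid M-major. The host-side sketch below (again illustrative, not part of the commit) mirrors BlockToCTileMap_N00_M0_N01Adapt::CalculateBottomIndex.
#include <utility>
std::pair<int, int> n00_m0_n01_map(int block_1d_id, int M0, int N0, int N01)
{
    block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
    const int idx_M0           = block_1d_id % M0;
    const int idx_N0           = block_1d_id / M0;
    const int N01_adapt        = (idx_N0 < N0 - N0 % N01) ? N01 : N0 % N01;
    const int idx_N00          = idx_N0 / N01;
    const int idx_N01          = idx_N0 % N01;
    const int idx_M0_N01_local = idx_M0 + idx_N01 * M0;
    return {idx_M0_N01_local / N01_adapt,                  // M tile index
            idx_M0_N01_local % N01_adapt + idx_N00 * N01}; // N tile index
}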
// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
...
...@@ -119,7 +119,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const InDataTypePointerTuple p_in_global_tuple,
-XDataType* const __restrict__ p_x_lds,
XDataType* const __restrict__ p_x_lds_,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
...@@ -149,7 +149,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto x_lds_val_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-p_x_lds, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);
p_x_lds_, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);
auto in_thread_buf_tuple = generate_tuple(
[&](auto) {
...
...@@ -268,6 +268,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
{
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
...@@ -329,6 +344,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
{
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
...
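Both KPadding branches above right-pad K to KPad and then unmerge it into (KBatch, K0Padded, K1), which implies KPad = KBatch * K0Padded * K1. The snippet below illustrates the arithmetic with made-up numbers; the ceil-division used for K0Padded is an assumption for the example, not a value read from this diff.
constexpr int K = 1000, K1 = 8, KBatch = 4;
constexpr int K0Padded = (K + K1 * KBatch - 1) / (K1 * KBatch); // assumed ceil split: 32
constexpr int KPad     = KBatch * K0Padded * K1;                // 4 * 32 * 8 = 1024
static_assert(KPad - K == 24, "the right-pad transform appends KPad - K elements");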
...@@ -328,7 +328,7 @@ struct WmmaSelector
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
-static constexpr auto GetWmma<int4_t, int, 16, 16>()
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu4;
}
...
...@@ -189,6 +189,7 @@ struct vector_type<T, 1>
}
};
int static err = 0;
template <typename T>
struct vector_type<T, 2>
{
...@@ -221,6 +222,10 @@ struct vector_type<T, 2>
{
return data_.d2x1_;
}
else
{
return err;
}
}
template <typename X>
...@@ -236,6 +241,10 @@ struct vector_type<T, 2>
{
return data_.d2x1_;
}
else
{
return err;
}
}
};
...@@ -278,6 +287,10 @@ struct vector_type<T, 4>
{
return data_.d4x1_;
}
else
{
return err;
}
}
template <typename X>
...@@ -298,6 +311,10 @@ struct vector_type<T, 4>
{
return data_.d4x1_;
}
else
{
return err;
}
}
};
...@@ -347,6 +364,10 @@ struct vector_type<T, 8>
{
return data_.d8x1_;
}
else
{
return err;
}
}
template <typename X>
...@@ -372,6 +393,10 @@ struct vector_type<T, 8>
{
return data_.d8x1_;
}
else
{
return err;
}
}
};
...@@ -428,6 +453,10 @@ struct vector_type<T, 16>
{
return data_.d16x1_;
}
else
{
return err;
}
}
template <typename X>
...@@ -458,6 +487,10 @@ struct vector_type<T, 16>
{
return data_.d16x1_;
}
else
{
return err;
}
}
};
...@@ -520,6 +553,10 @@ struct vector_type<T, 32>
{
return data_.d32x1_;
}
else
{
return err;
}
}
template <typename X>
...@@ -554,6 +591,10 @@ struct vector_type<T, 32>
{
return data_.d32x1_;
}
else
{
return err;
}
}
};
...@@ -623,6 +664,10 @@ struct vector_type<T, 64>
{
return data_.d64x1_;
}
else
{
return err;
}
}
template <typename X>
...@@ -662,6 +707,10 @@ struct vector_type<T, 64>
{
return data_.d64x1_;
}
else
{
return err;
}
}
};
...@@ -737,6 +786,10 @@ struct vector_type<T, 128>
{
return data_.d128x1_;
}
else
{
return err;
}
}
template <typename X>
...@@ -780,6 +833,10 @@ struct vector_type<T, 128>
{
return data_.d128x1_;
}
else
{
return err;
}
}
};
...@@ -861,6 +918,10 @@ struct vector_type<T, 256>
{
return data_.d256x1_;
}
else
{
return err;
}
}
template <typename X>
...@@ -908,6 +969,10 @@ struct vector_type<T, 256>
{
return data_.d256x1_;
}
else
{
return err;
}
}
};
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...@@ -19,6 +19,12 @@ struct is_known_at_compile_time<index_t>
static constexpr bool value = false;
};
template <>
struct is_known_at_compile_time<unsigned int>
{
static constexpr bool value = false;
};
template <>
struct is_known_at_compile_time<long_index_t>
{
...
...@@ -178,4 +178,15 @@ __host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
return math::max(TupleDepth<depth + 1>(Ts{})...);
}
template <index_t from, index_t to, typename... Ts>
__host__ __device__ constexpr auto TupleSlice(const Tuple<Ts...>& tuple)
{
return generate_tuple(
[&](auto i) {
using Idx = Number<from + i>;
return tuple.At(Idx{});
},
Number<to - from>{});
}
} // namespace ck
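A minimal usage sketch for the new TupleSlice helper (illustrative only): it copies the elements in the half-open index range [from, to) into a new tuple.
const auto t      = ck::make_tuple(2, 3, 4, 5);
const auto sliced = ck::TupleSlice<1, 3>(t); // yields a Tuple holding (3, 4)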
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...@@ -14,24 +14,28 @@ namespace wrapper {
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). It is possible to pass nested shapes
* (e.g. ((4, 2), 2)), nested dimensions are merged.
-* \tparam Strides Tuple of Number<> (for compile-time layout) or index_t
-* (dynamic layout). Stride tuple should be nested if shape tuple is
-* nested.
* \tparam UnrolledDescriptorType Tensor descriptor for unnested shape dims.
*/
-template <typename Shape, typename Strides>
template <typename Shape, typename UnrolledDescriptorType>
struct Layout
{
private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
-// Generate default idxs tuple (idx with all merged nested shapes)
/**
* \brief Generate default indices tuple (idx with all merged nested shapes)
*
* \param shape Shape to align.
* \return Multi idx tuple with zeros.
*/
template <typename... Ts>
-__host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple<Ts...>&)
__host__ __device__ constexpr static auto
GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple<Ts...>& shape)
{
return generate_tuple(
[&](auto) {
-if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
// runtime layout
return index_t(0);
...@@ -45,32 +49,18 @@ struct Layout
Number<Tuple<Ts...>::Size()>{});
}
-// Generate packed (column-major) strides if not passed
-template <typename... Ts>
-__host__ __device__ constexpr static auto
-GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
-{
-const auto unrolled_shape = UnrollNestedTuple(shape);
-return generate_tuple(
-[&](auto i) {
-if constexpr(i.value == 0)
-{
-return I1;
-}
-else
-{
-return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
-unrolled_shape);
-}
-},
-Number<decltype(unrolled_shape)::Size()>{});
-}
-// Generate LowerDims in Compile-time for MergeTrasform using passed Type
-// If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
-// If tuple is element, then pass through (sequence with one element)
/**
* \brief Generate lower dims in compile-time for the Merge transform using
* provided type. If element of nested Tuple<Ts...> is also a tuple, then
* merge (generate sequence for merge). If tuple is element, then pass
* through (sequence with one element).
*
* \param shape Shape to align.
* \return LowerDims for MergeTransform.
*/
template <typename Idx, typename... Ts>
-__host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
__host__ __device__ constexpr static auto
GenerateLowerDim([[maybe_unused]] const Tuple<Ts...>& shape)
{
if constexpr(Idx::value == 0)
{
...@@ -110,11 +100,17 @@ struct Layout
}
}
-// Iterate over nested tuples in shape
-// Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
-// Example idx: (1, 1), 1, 1
-// Example shape: (2, (2, 2)), 2, (2, 2)
-// Unrolled shape: 2, (2, 2), 2, (2, 2)
/**
* \brief Iterate over the nested tuples in the shape.
* Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
* Example idx: (1, 1), 1, 1
* Example shape: (2, (2, 2)), 2, (2, 2)
* Unrolled shape: 2, (2, 2), 2, (2, 2)
*
* \param shape Layout shape.
* \param idx Idx to align.
* \return Aligned shape.
*/
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx)
...@@ -149,6 +145,13 @@ struct Layout
}
}
/**
* \brief Merge descriptor to 1D.
*
* \param shape Layout shape.
* \param desc Descriptor to merge.
* \return 1D descriptor.
*/
template <typename... ShapeDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
const DescriptorToMerge& desc)
...@@ -160,18 +163,41 @@ struct Layout
const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
const auto upper_dims = make_tuple(Sequence<0>{});
// Merge to 1d
-return transform_tensor_descriptor(
-desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
return transform_tensor_descriptor(
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
}
else
{
// If the descriptor is known at the compilation time,
// use `make_merge_transform_v1_carry_check` because it doesn't use
// memcpy.
return transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform_v1_carry_check(merge_elems)),
lower_dims,
upper_dims);
}
}
-// Merge nested shape dims when corresponding index is also nested.
-// Input desc shape: 2, 2, 2, 2, 2, 2
-// Example idx: 1, 1, 1, 1
-// Example shape: 2, (2, 2), 2, (2, 2)
-// Merged shape: 2, 4, 2, 4
/**
* \brief Merge nested shape dims when corresponding index is also merged.
* Input desc shape: 2, 2, 2, 2, 2, 2
* Example idx: 1, 1, 1, (1, 1)
* Example shape: 2, (2, 2), 2, (2, 2)
* Merged shape: 2, 4, 2, 2, 2
*
* \param shape Layout shape.
* \param idxs Indexes to align descriptor.
* \param desc Descriptor to merge.
* \return Aligned descriptor to idx.
*/
template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
-__host__ __device__ constexpr static auto CreateMergedDescriptor(
-const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
__host__ __device__ constexpr static auto
CreateMergedDescriptor(const Tuple<ShapeDims...>& shape,
[[maybe_unused]] const Tuple<IdxDims...>& idxs,
DescriptorToMerge& desc)
{
const auto transforms = generate_tuple(
[&](auto i) {
...@@ -183,7 +209,17 @@ struct Layout
// If shape element is tuple and idx element is Number, then merge
// Unroll and reverse tuple to traverse column-major
const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
-return make_merge_transform(merge_elems);
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
return make_merge_transform(merge_elems);
}
else
{
// If the descriptor is known at the compilation time,
// use `make_merge_transform_v1_carry_check` because
// it doesn't use memcpy.
return make_merge_transform_v1_carry_check(merge_elems);
}
}
else
{
...@@ -207,33 +243,24 @@ struct Layout
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
}
-template <typename LayoutShape, typename LayoutStrides>
-__host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape,
-const LayoutStrides& strides)
-{
-const auto unrolled_shape = UnrollNestedTuple(shape);
-const auto unrolled_strides = UnrollNestedTuple(strides);
-static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
-"Size of strides and shape are not consistent.");
-return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
-}
-// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
-using DeducedStrides =
-std::conditional_t<is_same_v<Strides, Tuple<>>,
-remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
-Strides>;
-using FlattenDescriptorType =
-remove_cvref_t<decltype(MakeFlattenDescriptor(Shape{}, DeducedStrides{}))>;
using Descriptor1dType =
-remove_cvref_t<decltype(MakeMerge1d(Shape{}, FlattenDescriptorType{}))>;
remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
public:
/**
* \brief Transform descriptor to align to passed indexes.
*
* \param shape Layout shape.
* \param idxs Indexes to align descriptor.
* \param naive_descriptor Descriptor to merge.
* \return Aligned descriptor to idx.
*/
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto
TransformDesc(const Tuple<ShapeDims...>& shape,
-const Tuple<IdxDims...>& idx,
-const FlattenDescriptorType& naive_descriptor)
const Tuple<IdxDims...>& idxs,
const UnrolledDescriptorType& naive_descriptor)
{
if constexpr(Tuple<IdxDims...>::Size() == I1)
{
...@@ -249,55 +276,38 @@ struct Layout
static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
"Idx rank and Shape rank must be the same (except 1d).");
// Unroll while IdxDims is nested
-const auto aligned_shape = AlignShapeToIdx(shape, idx);
const auto aligned_shape = AlignShapeToIdx(shape, idxs);
// Transform correct form of shape
-return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor);
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor);
}
}
using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
-Shape{}, DefaultIdxsTupleType{}, FlattenDescriptorType{}))>;
Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>;
-public:
__host__ __device__ constexpr auto GetElementSpaceSize() const
{
-return flatten_descriptor_.GetElementSpaceSize();
return unrolled_descriptor_.GetElementSpaceSize();
}
__host__ __device__ Layout() = delete;
/**
* \brief Layout constructor.
*
* \param shape Shape for layout.
-* \param strides Strides for layout (optional if tensor is packed).
* \param unnested_descriptor Descriptor for unnested shape dims.
*/
-__host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides)
-: flatten_descriptor_{}, shape_(shape), strides_(strides)
__host__ __device__ constexpr Layout(const Shape& shape,
const UnrolledDescriptorType& unnested_descriptor)
: unrolled_descriptor_(unnested_descriptor), shape_(shape)
{
// Construct if runtime mode
-if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
-flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
-descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
-merged_nests_descriptor_ =
-TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
-}
-}
-/**
-* \brief Layout constructor (with default packed column-major strides).
-*
-* \param shape Shape for layout.
-*/
-__host__ __device__ constexpr Layout(const Shape& shape)
-: flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_))
-{
-if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
-{
-flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
-descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
-merged_nests_descriptor_ =
-TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_);
merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_);
}
}
...@@ -310,9 +320,9 @@ struct Layout
template <typename Idxs>
__host__ __device__ constexpr index_t operator()() const
{
-static_assert(FlattenDescriptorType::IsKnownAtCompileTime(),
static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(),
"Compiletime operator used on runtime layout.");
-using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{}));
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
}
...@@ -339,7 +349,7 @@ struct Layout
else
{
// Custom index, need to transform descriptor
-const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_);
const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_);
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
}
}
...@@ -351,7 +361,7 @@ struct Layout
* \return Calculated size.
*/
template <index_t IDim>
-__host__ __device__ constexpr index_t GetLength() const
__host__ __device__ constexpr auto GetLength() const
{
const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
...@@ -371,7 +381,7 @@ struct Layout
*
* \return Calculated size.
*/
-__host__ __device__ constexpr index_t GetLengths() const
__host__ __device__ constexpr auto GetLengths() const
{
const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
...@@ -385,13 +395,6 @@ struct Layout
*/
__host__ __device__ constexpr const Shape& GetShape() const { return shape_; }
-/**
-* \brief Strides getter.
-*
-* \return Strides.
-*/
-__host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; }
/**
* \brief Get default lengths (tuple filled with Shape length elements).
*
...@@ -413,21 +416,56 @@ struct Layout
}
/**
-* \brief Get default descriptor (with the same size as Shape)
* \brief Get descriptor with all nested dimensions merged.
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (4, 2)
*
* \note The size of merged descriptor is the same as Layout's shape.
*
-* \return Default descriptor.
* \return Merged nests descriptor.
*/
-__host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor()
__host__ __device__ constexpr const MergedNestsDescriptorType&
GetMergedNestingDescriptor() const
{
return merged_nests_descriptor_;
}
/**
* \brief Get descriptor with all dimensions merged (1D).
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (8)
*
* \return 1D descriptor.
*/
__host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const
{
return descriptor_1d_;
}
/**
* \brief Get unnested descriptor (with unrolled dims)
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (2, 2, 2)
*
* \return Flattened descriptor.
*/
__host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const
{
return unrolled_descriptor_;
}
private:
-FlattenDescriptorType flatten_descriptor_;
// All dimensions are unrolled
UnrolledDescriptorType unrolled_descriptor_;
// 1D descriptor
Descriptor1dType descriptor_1d_;
// All nesting are merged
MergedNestsDescriptorType merged_nests_descriptor_;
// Example, shape: ((2, 2), 2)
// UnrolledDescriptorType lengths: (2, 2, 2)
// Descriptor1dType lengths: (8)
// MergedNestsDescriptorType lengths: (4, 2)
const Shape shape_;
-const DeducedStrides strides_;
};
} // namespace wrapper
...
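A minimal construction sketch for the reworked Layout, assuming the unrolled descriptor is built with make_naive_tensor_descriptor_packed (the exact helper used by callers is not shown in this diff): the caller now passes the unrolled-dimensions descriptor instead of a stride tuple.
auto shape         = ck::make_tuple(ck::make_tuple(2, 2), 2); // nested shape ((2, 2), 2)
auto unrolled_desc = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(2, 2, 2));
auto layout        = ck::wrapper::Layout<decltype(shape), decltype(unrolled_desc)>{shape, unrolled_desc};
// Per the comments in the new code:
//   layout.GetUnrolledDescriptor()      -> lengths (2, 2, 2)
//   layout.Get1DDescriptor()            -> lengths (8)
//   layout.GetMergedNestingDescriptor() -> lengths (4, 2)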
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "../utils/tensor_utils.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace wrapper {
/**
* \brief Perform generic copy between two tensor partitions (threadwise copy).
* Tensors must have the same size.
*
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
*/
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
if constexpr(!SrcTensorType::IsDynamicBuffer)
{
using SizeType = decltype(size(src_tensor));
static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
}
else if constexpr(!DstTensorType::IsDynamicBuffer)
{
using SizeType = decltype(size(dst_tensor));
static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
}
else
{
for(int i = 0; i < size(src_tensor); i++)
{
dst_tensor(i) = src_tensor(i);
}
}
}
/**
* \brief Perform optimized copy between two tensor partitions (threadwise copy).
* Tensors must have the same size.
*
* \tparam DimAccessOrderTuple Tuple with dimension access order.
* \tparam VectorDim Dimension for vectorized read and write.
* \tparam ScalarPerVector Number of scalar per vectorized read and write.
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
*/
template <typename DimAccessOrderTuple,
index_t VectorDim,
index_t ScalarPerVector,
typename SrcTensorType,
typename DstTensorType>
__device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
const auto& in_grid_desc = layout(src_tensor).GetUnrolledDescriptor();
const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();
using SrcShapeType = remove_cvref_t<decltype(shape(src_tensor))>;
constexpr index_t num_dims = SrcShapeType::Size();
constexpr auto thread_slice_lengths =
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
constexpr auto dim_access_order = generate_sequence_v2(
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
if constexpr(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
{
// Perform a copy between DynamicBuffers
auto transfer = ThreadwiseTensorSliceTransfer_v7<
Tuple<typename SrcTensorType::TensorElementType>,
Tuple<typename DstTensorType::TensorElementType>,
decltype(tie(in_grid_desc)),
decltype(tie(out_grid_desc)),
tensor_operation::element_wise::PassThrough,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
Sequence<false>,
Sequence<false>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(tie(in_grid_desc),
tie(src_tensor.GetBuffer()),
tie(out_grid_desc),
tie(dst_tensor.GetBuffer()));
}
else if constexpr(!SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
{
// Perform copy from StaticBuffer to DynamicBuffer
const auto src_slice_origin_idxs =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
auto transfer =
ThreadwiseTensorSliceTransfer_v1r3<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
tensor_operation::element_wise::PassThrough,
decltype(thread_slice_lengths),
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
InMemoryDataOperationEnum::Set,
I1,
true>{out_grid_desc,
dst_tensor.GetMultiIdxOffsets(),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(in_grid_desc,
src_slice_origin_idxs,
src_tensor.GetBuffer(),
out_grid_desc,
dst_tensor.GetBuffer());
}
else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
{
// Perform copy from DynamicBuffer to StaticBuffer
const auto src_dst_slice_origin =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
[&](auto I) {
if constexpr(I == VectorDim)
{
return Number<ScalarPerVector>{};
}
else
{
return I1;
}
},
Number<num_dims>{});
auto transfer =
ThreadwiseTensorSliceTransfer_v4r1<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
decltype(src_vector_tensor_lengths),
decltype(dim_access_order)>{
src_tensor.GetMultiIdxOffsets()};
transfer.Run(in_grid_desc,
src_dst_slice_origin,
src_tensor.GetBuffer(),
out_grid_desc,
src_dst_slice_origin,
dst_tensor.GetBuffer());
}
else
{
// Perform copy between StaticBuffers
copy(src_tensor, dst_tensor);
}
}
} // namespace wrapper
} // namespace ck
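A hypothetical call of the vectorized overload is sketched below; src_thread_tile and dst_thread_tile are assumed wrapper tensors of equal size and are not defined in this diff.
// Illustrative only: visit dimension 0 first and read/write 8 contiguous scalars along dimension 1.
using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
ck::wrapper::copy<DimAccessOrder, /*VectorDim=*/1, /*ScalarPerVector=*/8>(src_thread_tile,
                                                                          dst_thread_tile);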
...@@ -265,6 +265,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
return 0;
}
throw std::runtime_error("Col2Img: number of dimensions should be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,
...
...@@ -313,6 +313,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
return 0;
}
throw std::runtime_error(
"Conv_bwd_data: number of dimensions must be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,
...
...@@ -265,6 +265,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
return 0;
}
throw std::runtime_error("Conv_bwd: number of dimensions must be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,
...