Commit f23a2e2a authored by Jakub Piasecki

resolved conflicts

parents f3eb5a18 c0adab48
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using BLayout = typename Base::BLayout;
using CLayout = typename Base::CLayout;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}
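
Note: `concat` (from the newly included ck_tile/host/concat.hpp) joins its arguments with the given separator, so the name above encodes precision, tile shape, vector sizes and padding flags. A minimal host-side sketch of the resulting format, using hypothetical values (a 128x128x32 tile, vector sizes 8, no padding; the exact `gemm_prec_str` output depends on the data types and is assumed to be "fp16" here):

#include <iostream>
#include <sstream>
#include <string>

// Stand-in for ck_tile::concat, illustration only: joins arguments with a separator.
template <typename... Ts>
std::string concat_sketch(char sep, const Ts&... ts)
{
    std::ostringstream oss;
    bool first = true;
    ((oss << (first ? std::string{} : std::string(1, sep)) << ts, first = false), ...);
    return oss.str();
}

int main()
{
    std::cout << concat_sketch('_', "gemm_batched", "fp16",
                               concat_sketch('x', 128, 128, 32),
                               concat_sketch('x', 8, 8, 8),
                               concat_sketch('x', 0, 0, 0))
              << '\n'; // prints: gemm_batched_fp16_128x128x32_8x8x8_0x0x0
}
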
struct BatchedGemmKernelArgs : GemmKernelArgs
{
index_t batch_stride_A;
@@ -70,7 +84,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
__host__ static constexpr auto
GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
{
-return TilePartitioner::GridSize(M, N, KBatch * batch_count);
+return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
}
__host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
@@ -101,14 +115,14 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
{
-const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
+const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
-const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
-const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
+const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.y);
+const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
-const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
+const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
// options
const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
@@ -128,7 +142,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
-if(kargs.KBatch == 1)
+if(kargs.k_batch == 1)
{
this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
...
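
The launch-geometry change above is the crux of this file: instead of packing batch and split-K into a single `blockIdx.z` and decoding with a division and a multiply-subtract, the new launch keeps them in separate grid dimensions. A small sketch of why both decodings enumerate the same (batch, split-k) pairs (plain C++, hypothetical counts):

#include <cassert>

int main()
{
    const int KBatch = 2, batch_count = 3;

    // Old: gridDim.z = KBatch * batch_count, decoded with div/mod.
    // New: gridDim.y = batch_count, gridDim.z = KBatch, read directly.
    for (int y = 0; y < batch_count; ++y)
        for (int z = 0; z < KBatch; ++z) {
            const int packed = y * KBatch + z; // old blockIdx.z
            assert(packed / KBatch == y);      // old i_batch == new blockIdx.y
            assert(packed % KBatch == z);      // old i_k     == new blockIdx.z
        }
    return 0;
}
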
@@ -8,6 +8,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -75,12 +76,19 @@ struct GemmKernel
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
-__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
+[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
-return TilePartitioner::GridSize(M, N, KBatch);
+// clang-format off
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
// clang-format on
}
-__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
+CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
struct GemmKernelArgs
{
@@ -93,7 +101,7 @@ struct GemmKernel
index_t stride_A;
index_t stride_B;
index_t stride_C;
-index_t KBatch;
+index_t k_batch;
};
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
@@ -121,7 +129,7 @@ struct GemmKernel
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
-const index_t K_t = kargs.KBatch * K1;
+const index_t K_t = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
@@ -142,13 +150,13 @@ struct GemmKernel
b_k_split_offset = k_id * KRead;
}
-if(k_id < static_cast<uint32_t>(kargs.KBatch - 1))
+if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
{
splitted_k = KRead;
}
else
{
-splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
+splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
}
}
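
A worked example of the `SplitKBatchOffset` arithmetic above, with hypothetical sizes (K = 1000 split 4 ways, warp-tile K1 = 32): every split except the last reads `KRead` elements of K, rounded up to a multiple of K1, and the last split takes the remainder.

#include <cassert>

int main()
{
    const int K = 1000, k_batch = 4;
    const int K1    = 32;                        // WarpTile K, hypothetical
    const int K_t   = k_batch * K1;              // 128
    const int KRead = (K + K_t - 1) / K_t * K1;  // ceil(1000/128) * 32 = 256

    int total = 0;
    for (int k_id = 0; k_id < k_batch; ++k_id) {
        // Mirrors SplitKBatchOffset: all splits but the last read KRead.
        const int splitted_k =
            (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);
        total += splitted_k;                     // 256, 256, 256, 232
    }
    assert(total == K); // the k_batch splits cover K exactly
    return 0;
}
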
@@ -159,14 +167,10 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{
-constexpr bool is_output_c_reg_transposed =
-EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
-if constexpr(!((GemmPipeline::VectorSizeC % 2 == 0 &&
-std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
-is_output_c_reg_transposed) ||
-!(std::is_same_v<CDataType, fp16_t> || std::is_same_v<CDataType, bf16_t>)))
+if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
+is_any_of<CDataType, fp16_t, bf16_t>::value)
{
-if(kargs.KBatch != 1)
+if(kargs.k_batch != 1)
{
std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
return false;
@@ -182,7 +186,7 @@ struct GemmKernel
<< std::endl;
return false;
}
-if(kargs.K % GemmPipeline::VectorSizeA != 0)
+if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
{
std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
return false;
@@ -197,7 +201,7 @@ struct GemmKernel
<< std::endl;
return false;
}
-if(kargs.M % GemmPipeline::VectorSizeA != 0)
+if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
{
std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
return false;
@@ -213,7 +217,7 @@ struct GemmKernel
<< std::endl;
return false;
}
-if(kargs.N % GemmPipeline::VectorSizeB != 0)
+if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
{
std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
return false;
@@ -228,7 +232,7 @@ struct GemmKernel
<< std::endl;
return false;
}
-if(kargs.K % GemmPipeline::VectorSizeB != 0)
+if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
{
std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
return false;
@@ -244,7 +248,7 @@ struct GemmKernel
<< std::endl;
return false;
}
-if(kargs.N % GemmPipeline::VectorSizeC != 0)
+if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
{
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false;
@@ -259,7 +263,7 @@ struct GemmKernel
<< std::endl;
return false;
}
-if(kargs.M % GemmPipeline::VectorSizeC != 0)
+if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
{
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false;
@@ -275,14 +279,6 @@ struct GemmKernel
const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
-// const auto idxs = TilePartitioner{}();
-// const auto i_m  = idxs.at(number<0>{});
-// const auto i_n  = idxs.at(number<1>{});
-// // options
-// const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
-// const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
-// // Convert pointers to tensor views
-// auto a_tensor_view = [&]() {
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
@@ -290,7 +286,7 @@ struct GemmKernel
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
-number<GemmPipeline::VectorSizeA>{},
+number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
@@ -299,7 +295,7 @@ struct GemmKernel
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
-number<GemmPipeline::VectorSizeA>{},
+number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
@@ -311,7 +307,7 @@ struct GemmKernel
b_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(kargs.stride_B, 1),
-number<GemmPipeline::VectorSizeB>{},
+number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
else
@@ -320,7 +316,7 @@ struct GemmKernel
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
-number<GemmPipeline::VectorSizeB>{},
+number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}();
@@ -333,7 +329,7 @@ struct GemmKernel
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
-number<GemmPipeline::VectorSizeC>{},
+number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
number<1>{});
}
else
@@ -501,22 +497,14 @@ struct GemmKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2);
-constexpr bool is_output_c_reg_transposed =
-EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
-if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
-(GemmPipeline::VectorSizeC % 2 == 0 &&
-std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
-is_output_c_reg_transposed))
-{
-EpiloguePipeline{}
-.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
-c_block_window, c_block_tile);
-}
+EpiloguePipeline{}
+.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
+c_block_window, c_block_tile, smem_ptr);
}
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
{
-const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
+const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
@@ -531,14 +519,20 @@ struct GemmKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
-if(kargs.KBatch == 1)
+if(kargs.k_batch == 1)
{
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
else
{
-RunGemm<memory_operation_enum::atomic_add>(
-a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+// Do not compile in case where we have unsupported
+// VectorSizeC & data type configuration.
+if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
+is_any_of<CDataType, fp16_t, bf16_t>::value))
+{
+RunGemm<memory_operation_enum::atomic_add>(
+a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+}
}
}
};
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
/**
* @file
* GemmTilePartitioner allows customized mapping between a workgroup and the C-tile it computes.
*/
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
-/** @brief Struct representing 2D block index mapping into 3D output tile space. */
+/**
* @brief Class providing 2D workgroup index mapping into 2D output GEMM C-tile space.
*
*/
template <typename BlockGemmShapeType>
struct GemmTile2DPartitioner
{
@@ -17,21 +25,32 @@ struct GemmTile2DPartitioner
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
-/** @brief Returns 3D grid size. */
-CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) noexcept(
-noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
+CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept = delete;
+CK_TILE_HOST_DEVICE GemmTile2DPartitioner([[maybe_unused]] index_t M,
+[[maybe_unused]] index_t N) noexcept;
/**
* @brief Calculates GEMM kernel grid size.
*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
* @return dim3 Structure holding grid's X,Y and Z dimensions.
*/
CK_TILE_HOST static auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
{
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
-const index_t GridDimZ = batch_size;
-return dim3(GridDimX, GridDimY, GridDimZ);
+return dim3(GridDimX, GridDimY, 1);
}
/**
-* @brief Returns the number of loops.
-* @param [in] K is dimension
+* @brief Calculate number of loop iterations over GEMM's K dimension.
+*
* @param K GEMM's K dimension.
* @return index_t The number of loop iterations over K dimension.
*/
-CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
+CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
{
return integer_divide_ceil(K, KPerBlock);
}
@@ -42,8 +61,15 @@ struct GemmTile2DPartitioner
* @param [in] blockIdy is blockIdx.y
* @return Returns the output tile indexes.
*/
-CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx,
-index_t blockIdy) noexcept
+/**
* @brief Calculate workgroup 2D index mapping into 2D output C-tile space.
*
* @param blockIdx WGP's X index.
* @param blockIdy WGP's Y index.
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
*/
CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept
-> const tuple<index_t, index_t>
{
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
@@ -53,61 +79,71 @@ struct GemmTile2DPartitioner
};
/**
-* @brief Struct representing 1D block index mapping into 2D output tile space.
+* @brief Class providing 1D WGP index mapping into 2D output C-tile space.
*
* @tparam BlockGemmShape_ A class providing basic GEMM parameters. \link TileGemmShape
*/
-template <typename BlockGemmShapeType>
+template <typename BlockGemmShape_>
struct GemmTile1DPartitioner
{
-using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
+using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
-/** @brief delete default ctr with no any object */
+CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept = delete;
-constexpr GemmTile1DPartitioner() noexcept = delete;
-/** @brief constructs an object that does contain a N value. */
-constexpr GemmTile1DPartitioner(index_t N) noexcept { N_ = N; }
-/** @brief Returns 1D grid size. */
-CK_TILE_HOST static constexpr auto
-GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
+/**
+* @brief Construct a new GemmTile1DPartitioner object.
+*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
*/
CK_TILE_HOST_DEVICE GemmTile1DPartitioner([[maybe_unused]] index_t M, index_t N) noexcept
{
-const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
-const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
-return dim3(GridDimX * GridDimY, 1, 1);
+N_ = N;
}
/**
-* @brief Returns the number of blocks in N.
-* @param [in] N is dimension
+* @brief Calculates GEMM kernel grid size.
+*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
* @return dim3 Structure holding grid's X,Y and Z dimensions.
*/
-CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) noexcept -> index_t
+CK_TILE_HOST static auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
{
-return integer_divide_ceil(N, NPerBlock);
+const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
return GridDimX * GridDimY;
}
/**
-* @brief Returns the number of loops.
-* @param [in] K is dimension
+* @brief Calculate number of loop iterations over GEMM's K dimension.
+*
* @param K GEMM's K dimension.
* @return index_t The number of loop iterations over K dimension.
*/
-CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
+CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
{
return integer_divide_ceil(K, KPerBlock);
}
/**
-* @brief The function returns 2D output tile space.
-* @param [in] blockIdx is blockIdx.x - block_start.
-* */
-CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx) noexcept
+* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
+*
+* @param blockIdx WGP's index.
+* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
*/
CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept
-> const tuple<index_t, index_t>
{
-const index_t NBlock = GetNBlock(N_);
-const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlock);
-const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - (iM)*NBlock);
+const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
+const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks);
+const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks);
return make_tuple(iM, iN);
}
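
The decode above is a plain row-major linearization over the N-tile count; the new constructor simply stores `N` so the device-side mapping can recompute `NBlocks`. A sketch with hypothetical sizes:

#include <cassert>

int main()
{
    const int N = 1000, NPerBlock = 128;
    const int NBlocks = (N + NPerBlock - 1) / NPerBlock; // integer_divide_ceil -> 8

    // Row-major 1D -> 2D decode, as in GetOutputTileIndex.
    for (int blockIdx = 0; blockIdx < 3 * NBlocks; ++blockIdx) {
        const int iM = blockIdx / NBlocks;
        const int iN = blockIdx - iM * NBlocks; // == blockIdx % NBlocks
        assert(iM * NBlocks + iN == blockIdx);  // the decode is a bijection
    }
    return 0;
}
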
@@ -141,21 +177,176 @@ struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIn
* enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
* otherwise std::false_type.
*/
-template <typename PartitionerFn,
-typename = typename std::enable_if_t<HasFnOneArgImpl<PartitionerFn>{}>>
+template <typename TilePartitioner,
+typename = typename std::enable_if_t<HasFnOneArgImpl<TilePartitioner>{}>>
struct OffsettedTile1DPartitioner
{
/**
* @brief The function subtracts the block's start (offset) from 1D raw-indexes.
-* @param [in] block_start is `blockIdx.x - block_start`.
+* @param [in] block_start Workgroup offset.
-* @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index.
+* @param [in] M Gemm's M dimension.
* @param [in] N Gemm's N dimension.
* @return Returns a `tuple` [Im, In] with shifted index.
*/
-[[nodiscard]] CK_TILE_DEVICE static constexpr auto GetOffsetedTileIndex(index_t block_start,
-index_t N) noexcept
+[[nodiscard]] CK_TILE_DEVICE static auto
+GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept
-> const tuple<index_t, index_t>
{
-const auto [iM, iN] = PartitionerFn(N).GetOutputTileIndex(blockIdx.x - block_start);
+const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
return make_tuple(iM, iN);
}
};
/**
* @brief Class mapping 1D block index into 2D output tile space.
*
* @note It groups spatially workgroups in order to better utilize caches.
* It is using grouped Rows of column-vectors WGP pattern. It's optimized
* for gfx94x-like multiple-die chip.
*
* @tparam GroupNum - The number of big groups.
* @tparam M01 - The number of groups in M dim within spatially local WGPs,
*
*/
template <typename BlockGemmShapeType, index_t GroupNum, index_t M01>
struct GemmSpatiallyLocalTilePartitioner
{
using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept = delete;
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner(index_t M_, index_t N_) noexcept
: M(M_), N(N_)
{
}
/**
* @brief Calculates GEMM kernel grid size.
*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
* @return index_t A total number of workgroups.
*/
CK_TILE_HOST static auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
{
const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
return GridDimX * GridDimY;
}
/**
* @brief Calculate number of loop iterations over GEMM's K dimension.
*
* @param K GEMM's K dimension.
* @return index_t The number of loop iterations over K dimension.
*/
CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
{
return integer_divide_ceil(K, KPerBlock);
}
/**
* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
*
* @param [in] block_1d_id WGP's index.
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
*/
CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept
-> const tuple<index_t, index_t>
{
const auto M0 = integer_divide_ceil(M, MPerBlock);
const auto N0 = integer_divide_ceil(N, NPerBlock);
if(M0 == 1)
{
return make_tuple(0, block_1d_id);
}
else if(N0 == 1)
{
return make_tuple(block_1d_id, 0);
}
// block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
else
{
const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
const auto group_id_y = block_1d_id / GroupNum;
const auto group_id_x = block_1d_id - group_id_y * GroupNum;
const auto remap_block_1d_id =
group_id_x <= big_group_num
? group_id_x * group_size + group_id_y
: group_id_x * group_size + big_group_num - group_id_x + group_id_y;
const index_t idx_M0 = remap_block_1d_id / N0;
const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0;
const index_t M0_tmp = M0 / M01;
const index_t M0_mod_M01 = M0 - M0_tmp * M01;
const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
const index_t idx_M00 = idx_M0 / M01;
const index_t idx_M01 = idx_M0 - idx_M00 * M01;
const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
const index_t N_out = idx_N0_M01_local / M01_adapt;
const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;
return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out);
}
}
private:
index_t M;
index_t N;
};
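
The worked example embedded in the comment above can be replayed directly. The sketch below reproduces only the `M01` swizzle stage of `GetOutputTileIndex` (the group remap is taken as identity, which matches the comment's numbers; `GroupNum` is not exercised):

#include <cassert>
#include <utility>

// Replays the post-remap part of GetOutputTileIndex for one block id.
std::pair<int, int> swizzle(int remap_block_1d_id, int M0, int N0, int M01)
{
    const int idx_M0 = remap_block_1d_id / N0;
    const int idx_N0 = remap_block_1d_id - idx_M0 * N0;

    const int M0_mod_M01 = M0 - (M0 / M01) * M01; // M0 % M01
    const int M01_adapt  = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;

    const int idx_M00          = idx_M0 / M01;
    const int idx_M01          = idx_M0 - idx_M00 * M01;
    const int idx_N0_M01_local = idx_N0 + idx_M01 * N0;

    const int N_out           = idx_N0_M01_local / M01_adapt;
    const int idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;
    return {idx_loc_mod_M01 + idx_M00 * M01, N_out};
}

int main()
{
    // Values from the comment: M0=5, N0=4, block id 5, M01=2 -> tile {1, 2}.
    auto [iM, iN] = swizzle(5, 5, 4, 2);
    assert(iM == 1 && iN == 2);
    return 0;
}
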
} // namespace ck_tile
@@ -64,6 +64,18 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
}
};
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}
__host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
-> std::size_t
{
@@ -77,8 +89,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
index_t grid_size = 0;
for(const auto& it_desc : gemm_descs)
{
-const auto dim3 = TilePartitioner::GridSize(it_desc.M, it_desc.N);
-grid_size += dim3.x * dim3.y * 1;
+const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
+grid_size += local_grid_size * it_desc.k_batch;
}
return dim3(grid_size, 1, 1);
}
@@ -106,8 +118,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const index_t stride_b = gemm_descs[i].stride_B;
const index_t stride_c = gemm_descs[i].stride_C;
-const auto dim3 = TilePartitioner::GridSize(M, N);
-const index_t grid_size_grp = dim3.x;
+const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
const index_t block_start = grid_size;
const index_t block_end = grid_size + grid_size_grp;
@@ -138,8 +149,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const
{
-const auto [iM, iN] =
-OffsetTile1DPartitioner::GetOffsetedTileIndex(kargs.block_start, kargs.group_karg.N);
+const auto [iM, iN] = OffsetTile1DPartitioner::GetOffsetedTileIndex(
+kargs.block_start, kargs.group_karg.M, kargs.group_karg.N);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
...
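
For context on the two hunks above: the grouped kernel flattens all groups into one 1D grid; each group records a `[block_start, block_end)` slice of it, and `GetOffsetedTileIndex` subtracts `block_start` so the per-group partitioner sees local block ids. A host-side sketch of that bookkeeping, with hypothetical per-group sizes:

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

int main()
{
    // Hypothetical per-group workgroup counts: GridSize(M, N) * k_batch.
    const std::vector<int> grid_size_grp = {12, 40, 7};

    // Prefix-sum the counts into [block_start, block_end) ranges.
    std::vector<std::pair<int, int>> ranges;
    int grid_size = 0;
    for (int g : grid_size_grp) {
        ranges.push_back({grid_size, grid_size + g});
        grid_size += g;
    }

    // A workgroup finds its group and its local id by range lookup,
    // mirroring GetOffsetedTileIndex(blockIdx.x - block_start, ...).
    const int block_id = 45; // some blockIdx.x
    for (std::size_t g = 0; g < ranges.size(); ++g)
        if (block_id >= ranges[g].first && block_id < ranges[g].second) {
            const int local_id = block_id - ranges[g].first;
            assert(g == 1 && local_id == 33); // 45 falls in group 1
        }
    return 0;
}
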
@@ -21,6 +21,8 @@ struct GemmPipelineAgBgCrImplBase
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
template <typename DstBlockTile, typename SrcTileWindow, typename DramTileWindowStep>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window,
...
@@ -3,10 +3,14 @@
#pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -20,6 +24,8 @@ struct BaseGemmPipelineAgBgCrCompV3
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
@@ -62,9 +68,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
-static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA<Problem>();
-static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
-static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC<Problem>();
+static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
+static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
+static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
@@ -76,14 +82,68 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
using Base::PrefetchStages;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AgBgCrCompV3", BlockSize,
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
-CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
+CK_TILE_HOST static std::string Print()
{
-return Policy::template IsTransposeC<Problem>();
+constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
// Below should be equal to AK1|BK1
constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
auto str = std::stringstream{};
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
<< "\n"
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
}
template <GemmPipelineScheduler Scheduler>
@@ -98,29 +158,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
-constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
-constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
-constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
+constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
+constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
+constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
-constexpr index_t A_LDS_Read_Width = KPerXDL;
-constexpr index_t B_LDS_Read_Width = KPerXDL;
+// Below should be equal to AK1|BK1
+constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
+constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
+constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
+constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_Buffer_Load_Inst_Num =
-MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
+MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
-NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
+NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
-constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
-constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
+constexpr index_t A_LDS_Write_Inst_Num =
+MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
+constexpr index_t B_LDS_Write_Inst_Num =
+NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
-WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
+WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
-WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
+WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
...
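
The instruction counts computed in `Print()` and `HotLoopScheduler()` are simple tile arithmetic. Plugging one hypothetical configuration into the same formulas (these numbers are illustrative, not taken from the diff):

#include <cassert>

int main()
{
    // Hypothetical pipeline configuration.
    const int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32;
    const int BlockSize = 256, WaveSize = 64;
    const int VectorSizeA = 8, VectorSizeB = 8;
    const int A_LDS_Width = 8;                  // GetSmemPackA
    const int WaveNumN = 2;
    const int MPerXDL = 32, NPerXDL = 32, KPerXDL = 8;

    // Each thread moves VectorSize elements per buffer_load.
    const int A_Buffer_Load = MPerBlock * KPerBlock / (BlockSize * VectorSizeA); // 2
    const int B_Buffer_Load = NPerBlock * KPerBlock / (BlockSize * VectorSizeB); // 2
    const int A_LDS_Write   = MPerBlock * KPerBlock / (BlockSize * A_LDS_Width); // 2
    const int A_LDS_Read    = WaveNumN * MPerBlock * KPerBlock
                            / (BlockSize * A_LDS_Width);                         // 4
    const int C_MFMA        = MPerBlock * NPerBlock * KPerBlock
                            / (BlockSize / WaveSize)
                            / (MPerXDL * NPerXDL * KPerXDL);                     // 16

    assert(A_Buffer_Load == 2 && B_Buffer_Load == 2);
    assert(A_LDS_Write == 2 && A_LDS_Read == 4);
    assert(C_MFMA == 16);
    return 0;
}
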
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -7,6 +7,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -20,6 +21,8 @@ struct BaseGemmPipelineAgBgCrMem
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
@@ -88,7 +91,7 @@ struct BaseGemmPipelineAgBgCrMem
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
-template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
+template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
@@ -113,9 +116,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
-static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA<Problem>();
-static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
-static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC<Problem>();
+static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
+static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
+static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
@@ -126,6 +129,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AgBgCrMe",
concat('x', MPerBlock, NPerBlock, KPerBlock),
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
using Base::PrefetchStages;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -133,11 +146,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
return Policy::template GetSmemSize<Problem>();
}
-CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
-{
-return Policy::template IsTransposeC<Problem>();
-}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
@@ -168,11 +176,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
-static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
-NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
-KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
-"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
-" or KPerBlock!");
+constexpr bool is_a_col_major =
+std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
+constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
+static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
@@ -216,25 +235,59 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
-Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
-Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
+Base::GlobalPrefetch(
+a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
+Base::GlobalPrefetch(
+b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
-Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
-Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+if constexpr(is_a_col_major)
+{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
}
// Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
-Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
-Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
+Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
+a_copy_dram_window,
+a_dram_tile_window_step);
+Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
+b_copy_dram_window,
+b_dram_tile_window_step);
});
// main body
@@ -250,19 +303,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds();
-Base::LocalPrefill(
-a_copy_lds_window,
-a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-a_element_func);
-Base::LocalPrefill(
-b_copy_lds_window,
-b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-b_element_func);
+if constexpr(is_a_col_major)
+{
+auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+Policy::template MakeShuffledARegTileDistribution<Problem>());
+transpose_tile2d(
+a_shuffle_tmp,
+a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(
b_shuffle_tmp,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
}
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
-a_copy_dram_window);
+a_copy_dram_window,
+a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
-b_copy_dram_window);
+b_copy_dram_window,
+b_dram_tile_window_step);
});
i += PrefetchStages;
@@ -278,12 +357,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds();
-Base::LocalPrefill(a_copy_lds_window,
-a_block_tiles.get(number<prefetch_idx>{}),
-a_element_func);
-Base::LocalPrefill(b_copy_lds_window,
-b_block_tiles.get(number<prefetch_idx>{}),
-b_element_func);
+if constexpr(is_a_col_major)
+{
+auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+Policy::template MakeShuffledARegTileDistribution<Problem>());
+transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
+Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}),
b_element_func);
}
});
block_sync_lds();
@@ -355,11 +454,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
-static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
-NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
-KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
-"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
-" or KPerBlock!");
+constexpr bool is_a_col_major =
+std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
+constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
+static_assert(is_a_col_major
+? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
@@ -403,25 +513,58 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
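// Each prefetch advances the DRAM window by KPerBlock along whichever window
// dimension carries K for the given layout (dim 0 when K is outermost, dim 1
// otherwise), so both layout variants walk the same K direction in memory.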
// -----------------------------------------------------------------------------------------
// Gemm pipeline start

// prefetch
// global read 0
Base::GlobalPrefetch(
    a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
    b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);

// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

// LDS write 0
if constexpr(is_a_col_major)
{
    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
        Policy::template MakeShuffledARegTileDistribution<Problem>());
    transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
    Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
}
if constexpr(is_b_row_major)
{
    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
        Policy::template MakeShuffledBRegTileDistribution<Problem>());
    transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
    Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
}
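// When the global tile arrives K-major (col-major A / row-major B), it is
// transposed in registers via transpose_tile2d before the LDS write, so LDS
// always holds the arrangement the block GEMM windows read back.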
// Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
    Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
                         a_copy_dram_window,
                         a_dram_tile_window_step);
    Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
                         b_copy_dram_window,
                         b_dram_tile_window_step);
});
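// With PrefetchStages tiles now in flight, the main loop can overlap global
// reads of future K-tiles with block GEMM compute on the current one.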
// main body
...@@ -435,19 +578,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>

block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// no second block_sync_lds because it's interwave
if constexpr(is_a_col_major)
{
    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
        Policy::template MakeShuffledARegTileDistribution<Problem>());
    transpose_tile2d(
        a_shuffle_tmp,
        a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
    Base::LocalPrefill(
        a_copy_lds_window,
        a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
        a_element_func);
}
if constexpr(is_b_row_major)
{
    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
        Policy::template MakeShuffledBRegTileDistribution<Problem>());
    transpose_tile2d(
        b_shuffle_tmp,
        b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
    Base::LocalPrefill(
        b_copy_lds_window,
        b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
        b_element_func);
}

Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
                     a_copy_dram_window,
                     a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
                     b_copy_dram_window,
                     b_dram_tile_window_step);
});
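// Register prefetch slots in the loop above are recycled round-robin: slot
// (prefetch_idx + 1) % PrefetchStages is drained into LDS while slot
// prefetch_idx is refilled from global memory.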
i += PrefetchStages;
...@@ -460,12 +629,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>

block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// no second block_sync_lds because it's interwave
if constexpr(is_a_col_major)
{
    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
        Policy::template MakeShuffledARegTileDistribution<Problem>());
    transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
    Base::LocalPrefill(a_copy_lds_window,
                       a_block_tiles.get(number<prefetch_idx>{}),
                       a_element_func);
}
if constexpr(is_b_row_major)
{
    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
        Policy::template MakeShuffledBRegTileDistribution<Problem>());
    transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
    Base::LocalPrefill(b_copy_lds_window,
                       b_block_tiles.get(number<prefetch_idx>{}),
                       b_element_func);
}
});
block_sync_lds();
...

// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ostream>
#include <sstream>

#include "ck_tile/core.hpp"

...
...@@ -5,6 +5,7 @@

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {

...@@ -31,21 +32,33 @@ struct GemmPipelineAGmemBGmemCRegV1

static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;

static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }

static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr index_t kLdsAlignmentInBytes = 16;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegV1",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
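// Illustrative only: for a hypothetical 128x128x32 tile, 256 threads, 8-wide
// vector access and no padding, this would yield a name like
// "pipeline_AGmemBGmemCRegV1_128x128x32x256_8x8x8_0x0x0" (exact formatting
// follows ck_tile::concat).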
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    return Policy::template GetSmemSize<Problem>();
}

template <typename ADramBlockWindowTmp,
          typename BDramBlockWindowTmp,
          typename AElementFunction,

...@@ -75,8 +88,9 @@ struct GemmPipelineAGmemBGmemCRegV1
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

constexpr index_t a_lds_block_space_size_aligned =
    integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
                        kLdsAlignmentInBytes) *
    kLdsAlignmentInBytes;
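// Rounds the A-tile byte count up to the next multiple of kLdsAlignmentInBytes
// (16), e.g. 1000 bytes -> ceil(1000 / 16) * 16 = 1008, so the B tile placed
// after it in LDS can start on an aligned boundary.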
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(

...
...@@ -16,8 +16,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy

static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};

// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()

...@@ -383,8 +381,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy

    }
}

template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{

...@@ -397,7 +393,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy

                                             WarpTile::at(I0),
                                             WarpTile::at(I1),
                                             WarpTile::at(I2),
                                             Problem::TransposeC>;
    using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
                                                                  typename Problem::BDataType,
                                                                  typename Problem::CDataType,

...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {
...@@ -25,6 +26,15 @@ struct GemmPipelineAGmemBGmemCRegV2

static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegV2",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize));
// clang-format on
}
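// Illustrative only: a hypothetical 128x128x32 tile with 256 threads would be
// named "pipeline_AGmemBGmemCRegV2_128x128x32x256".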
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
    return integer_divide_ceil(

...@@ -36,8 +46,6 @@ struct GemmPipelineAGmemBGmemCRegV2

        Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}

template <typename ADramBlockWindowTmp,
          typename BDramBlockWindowTmp,
          typename AElementFunction,

...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {
...@@ -27,15 +28,27 @@ struct GemmPipelineProblemBase

using BLayout = remove_cvref_t<typename Traits::BLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;

static constexpr bool TransposeC = Traits::TransposeC;

static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();

static constexpr bool kPadM = Traits::kPadM;
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;

static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm_problem",
concat('x', VectorLoadSize, kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler);
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
    if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)

...@@ -111,7 +124,6 @@ struct GemmPipelineProblemBase

        return kPadK ? 1 : GetAlignmentB();
    }
}();

static constexpr index_t VectorSizeC = []() {
    if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
    {

...
...@@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy

template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;

...@@ -519,7 +518,7 @@ struct UniversalGemmPipelineAgBgCrPolicy

    using ALayout = remove_cvref_t<typename Problem::ALayout>;
    static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);

    constexpr index_t BlockSize = Problem::kBlockSize;
    constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
    constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

    constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();

...@@ -549,12 +548,6 @@ struct UniversalGemmPipelineAgBgCrPolicy

    return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
}

template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{

...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {
...@@ -19,6 +20,16 @@ struct TileGemmShape

static constexpr index_t kM = BlockTile::at(number<0>{});
static constexpr index_t kN = BlockTile::at(number<1>{});
static constexpr index_t kK = BlockTile::at(number<2>{});
CK_TILE_HOST static std::string GetName()
{
// clang-format off
return concat('_', "tile_gemm_shape",
concat('x', kM, kN, kK, NumWarps),
concat('x', BlockWarps::at(number<0>{}), BlockWarps::at(number<1>{}), BlockWarps::at(number<2>{})),
concat('x', (WarpTile::at(number<0>{})), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})));
// clang-format on
}
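// Illustrative only: a hypothetical TileGemmShape<sequence<256, 256, 32>,
// sequence<2, 2, 1>, sequence<32, 32, 8>> (assuming NumWarps = 4 for a 2x2x1
// warp grid) would be named "tile_gemm_shape_256x256x32x4_2x2x1_32x32x8".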
};

} // namespace ck_tile
...@@ -8,3 +8,4 @@

#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

...@@ -11,3 +11,4 @@

#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

...@@ -8,3 +8,4 @@

#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

...@@ -7,3 +7,4 @@

#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

...@@ -9,3 +9,4 @@

#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

...@@ -11,3 +11,4 @@

#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"