Unverified Commit 25e2e0f0 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

[CK TILE] Implement cschuflle algorithm (#1842)

* [CK TILE] Implement cschuflle algorithm

* Rebase

* Vector store size fixes

* fixes

* Fixes

* fixes

* fmha fix

* fixes

* fixes of fixes
parent c5fff071
...@@ -20,10 +20,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -20,10 +20,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr bool kPadN = false; constexpr bool kPadN = false;
constexpr bool kPadK = false; constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
// This part comes from the Codegen // This part comes from the Codegen
...@@ -39,11 +35,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -39,11 +35,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 8;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape = using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>, ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>, ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
...@@ -51,26 +42,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -51,26 +42,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>; using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
...@@ -60,9 +60,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -60,9 +60,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>; using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile:: using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>; TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
...@@ -95,6 +92,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -95,6 +92,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using GemmPipeline = using GemmPipeline =
GEMM_PIPELINE<UniversalGemmProblem, ck_tile::UniversalGemmPipelineAgBgCrPolicy>; GEMM_PIPELINE<UniversalGemmProblem, ck_tile::UniversalGemmPipelineAgBgCrPolicy>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args); auto kargs = Kernel::MakeKernelArgs(args);
......
...@@ -22,9 +22,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre ...@@ -22,9 +22,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
constexpr bool kPadM = false; constexpr bool kPadM = false;
constexpr bool kPadN = false; constexpr bool kPadN = false;
constexpr bool kPadK = false; constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
...@@ -41,11 +38,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre ...@@ -41,11 +38,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 8;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape = using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>, ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>, ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
...@@ -53,26 +45,24 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre ...@@ -53,26 +45,24 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>; using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>; using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
...@@ -23,9 +23,6 @@ struct GroupedGemmKernelParam ...@@ -23,9 +23,6 @@ struct GroupedGemmKernelParam
static const bool kPadM = false; static const bool kPadM = false;
static const bool kPadN = false; static const bool kPadN = false;
static const bool kPadK = false; static const bool kPadK = false;
static const bool kTilePermute = false;
static const ck_tile::index_t kOutputRank = 2;
static const int kBlockPerCu = 1; static const int kBlockPerCu = 1;
static const ck_tile::index_t M_Tile = 128; static const ck_tile::index_t M_Tile = 128;
...@@ -54,24 +51,6 @@ using CodegenGemmShape = ...@@ -54,24 +51,6 @@ using CodegenGemmShape =
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>; using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
template <typename CLayout>
using GemmEpilogue = std::conditional_t<
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
GroupedGemmKernelParam::kPadM,
GroupedGemmKernelParam::kPadN,
GroupedGemmKernelParam::kTilePermute,
GroupedGemmKernelParam::kOutputRank,
1,
0,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<AccDataType,
CDataType,
GroupedGemmKernelParam::kPadM,
GroupedGemmKernelParam::kPadN>>>;
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemmKernelParam::kPadM, using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemmKernelParam::kPadM,
GroupedGemmKernelParam::kPadN, GroupedGemmKernelParam::kPadN,
...@@ -92,10 +71,25 @@ template <typename ALayout, typename BLayout, typename CLayout> ...@@ -92,10 +71,25 @@ template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmPipeline = using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>; ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
template <typename ALayout, typename BLayout, typename CLayout>
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem<ALayout, BLayout, CLayout>::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemmKernelParam::M_Warp,
GroupedGemmKernelParam::N_Warp,
GroupedGemmKernelParam::M_Warp_Tile,
GroupedGemmKernelParam::N_Warp_Tile,
GroupedGemmKernelParam::K_Warp_Tile,
CodegenPipelineProblem<ALayout, BLayout, CLayout>::TransposeC>>;
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
CodegenGemmPipeline<ALayout, BLayout, CLayout>, CodegenGemmPipeline<ALayout, BLayout, CLayout>,
GemmEpilogue<CLayout>>; GemmEpilogue<ALayout, BLayout, CLayout>>;
}; // namespace }; // namespace
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs) std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#define CK_TILE_MAX_RANK 5 #include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile { namespace ck_tile {
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template <typename AccDataType_, template <typename AccDataType_,
typename ODataType_, typename ODataType_,
bool kPadM_, typename CLayout_,
bool kPadN_, index_t kBlockSize_,
bool kTilePermute_, index_t kM_,
index_t kRank_, index_t kN_,
index_t kPerm0, index_t kMWave_,
index_t kPerm1, index_t kNWave_,
index_t TileSize0, index_t kMPerXdl_,
index_t TileSize1, index_t kNPerXdl_,
index_t kPerm2 = 0, index_t kKPerXdl_,
index_t kPerm3 = 0, bool isCTransposed_>
index_t kPerm4 = 0,
index_t TileSize2 = 0,
index_t TileSize3 = 0,
index_t TileSize4 = 0>
struct CShuffleEpilogueProblem struct CShuffleEpilogueProblem
{ {
using AccDataType = remove_cvref_t<AccDataType_>; using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>; using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_; using CLayout = remove_cvref_t<CLayout_>;
static constexpr bool kPadN = kPadN_; static constexpr index_t kBlockSize = kBlockSize_;
static constexpr bool kTilePermute = kTilePermute_; static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kRank = kRank_; static constexpr index_t kNPerBlock = kN_;
static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4}; static constexpr index_t kMWave = kMWave_;
static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = { static constexpr index_t kNWave = kNWave_;
TileSize0, TileSize1, TileSize2, TileSize3, TileSize4}; static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
static constexpr index_t kKPerXdl = kKPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
}; };
template <typename Problem_, typename Policy_ = void> template <typename Problem_, typename Policy_ = void>
...@@ -46,149 +43,147 @@ struct CShuffleEpilogue ...@@ -46,149 +43,147 @@ struct CShuffleEpilogue
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>; using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM; using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr bool kPadN = Problem::kPadN; static constexpr index_t kBlockSize = Problem::kBlockSize;
const index_t* kPerm = Problem::kPerm; static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr bool kTilePermute = Problem::kTilePermute; static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t kRank = Problem::kRank; static constexpr index_t kMWave = Problem::kMWave;
const index_t* tile_sizes = Problem::tile_sizes; static constexpr index_t kNWave = Problem::kNWave;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
// No additional shared memory needed static constexpr index_t kNPerXdl = Problem::kNPerXdl;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
{ static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
// TODO: At now CShuffle doesn't allow to vector store after permute.
// It should be fixed and this function should return true. using WG = WarpGemmMfmaDispatcher<ODataType,
return false; ODataType,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
/**
* @brief Get the vector store size for C tensor.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
constexpr index_t MaxVectorStoreSize = 16;
return MaxVectorStoreSize / sizeof(ODataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
make_tuple(number<kNWave * kNPerXdl>{}, number<1>{}));
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
make_tuple(number<1>{}, number<kMWave * kMPerXdl>{}));
} }
else
template <typename OAccTile>
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
{
using DataType = typename OAccTile::DataType;
// Get thread buffer
auto& thread_buf = o_acc_tile.get_thread_buffer();
// Create a temporary buffer to hold the permuted data
thread_buffer<DataType, OAccTile::kThreadElementSpaceSize> permuted_thread_buf;
// Get the lengths of each dimension
auto thread_tensor_lengths = o_acc_tile.get_lengths();
// Total number of elements
index_t total_elements = OAccTile::kThreadElementSpaceSize;
// Iterate over all elements
for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx)
{ {
// Convert linear index to multi-dimensional indices static_assert(false, "Unsupported CLayout!");
array<index_t, kRank> indices; }
index_t remaining = linear_idx;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
indices(rev_i) = remaining % thread_tensor_lengths.get(number<rev_i>{});
remaining /= thread_tensor_lengths.get(number<rev_i>{});
});
// Apply the permutation
array<index_t, kRank> permuted_indices;
static_for<0, kRank, 1>{}(
[&](auto i) { permuted_indices(i) = indices.get(number<Problem::kPerm[i]>{}); });
// Compute offsets
index_t dst_offset = 0;
index_t stride = 1;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
dst_offset += permuted_indices[rev_i] * stride;
stride *= thread_tensor_lengths.get(number<rev_i>{});
});
// Move the data
permuted_thread_buf(dst_offset) = thread_buf[linear_idx];
} }
// Copy the permuted data back to the original thread buffer CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
for(index_t i = 0; i < total_elements; ++i)
{ {
thread_buf.set_as(i, permuted_thread_buf.get(i)); return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
}
} }
template <typename ODramWindowTmp, template <typename ODramWindow,
typename OAccTile, typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set> memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile) CK_TILE_DEVICE auto
{ operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
const auto& current_window_origin = o_dram_window_tmp.get_window_origin(); {
// Compute the tile coordinates by dividing the window origin by the tile sizes const index_t iMWarp = get_warp_id() / kNWave;
index_t tile_coords[CK_TILE_MAX_RANK] = {0}; const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
for(index_t i = 0; i < kRank; ++i)
{ constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
tile_coords[i] = current_window_origin[i] / tile_sizes[i]; auto o_lds_block = make_tensor_view<address_space_enum::lds>(
// printf("The tile_coord is: %d", tile_coords[i]); static_cast<ODataType*>(p_smem), lds_block_desc);
} auto in_lds_window =
make_tile_window(o_lds_block,
make_tuple(number<kMPerXdl>{}, number<kNPerXdl>{}),
{number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
auto out_lds_window =
make_tile_window(o_lds_block,
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<kMPerXdl * kMWave, kNPerXdl * kNWave>>;
constexpr index_t num_access = SFC::get_num_of_access();
using TileEncodingPattern =
TileDistributionEncodingPattern2D<kBlockSize,
kMPerIteration,
kNPerIteration,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
CWarpTensor c_warp_in_tensor;
static_for<0, num_access, 1>{}([&](auto iAccess) {
constexpr auto idx_y_start = SFC::get_index(iAccess);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
block_sync_lds();
store_tile(in_lds_window, c_warp_in_tensor_casted);
block_sync_lds();
const auto c_out_tensor =
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
// Apply the permutation to the tile coordinates
index_t permuted_tile_coords[CK_TILE_MAX_RANK];
for(index_t i = 0; i < kRank; ++i)
{
permuted_tile_coords[i] = tile_coords[kPerm[i]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename ODramWindowTmp::BottomTensorIndex step = {};
for(index_t i = 0; i < kRank; ++i)
{
step[i] = permuted_window_origin[i] - current_window_origin[i];
}
// Move the window
move_tile_window(o_dram_window_tmp, step);
// Permute the data within the tile if necessary
if constexpr(kTilePermute)
{
permute_tile_data(o_acc_tile);
}
// Store the tile data to the permuted location
if constexpr(kPadM || kPadN)
{
if constexpr(out_memory_data_op == memory_operation_enum::set) if constexpr(out_memory_data_op == memory_operation_enum::set)
{ {
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
buffer_store_fence();
} }
else else
{ {
if constexpr(out_memory_data_op == memory_operation_enum::set) update_tile(out_dram_window, c_out_tensor);
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
} }
else if constexpr(iAccess != num_access - 1)
{ {
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); constexpr auto step = SFC::get_forward_step(iAccess);
} move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
} }
});
} }
}; };
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -23,6 +25,26 @@ struct Default2DEpilogueProblem ...@@ -23,6 +25,26 @@ struct Default2DEpilogueProblem
static constexpr bool UseRawStore = UseRawStore_; static constexpr bool UseRawStore = UseRawStore_;
}; };
template <typename AccDataType_,
typename ODataType_,
typename CLayout_,
bool kPadM_,
bool kPadN_,
index_t kMPerXdl_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_,
bool UseRawStore_ = true>
struct DefaultGemm2DEpilogueProblem
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
{
using CLayout = remove_cvref_t<CLayout_>;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
static constexpr index_t kKPerXdl = kKPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
};
template <typename Problem_, typename Policy_ = void> template <typename Problem_, typename Policy_ = void>
struct Default2DEpilogue struct Default2DEpilogue
{ {
...@@ -35,14 +57,13 @@ struct Default2DEpilogue ...@@ -35,14 +57,13 @@ struct Default2DEpilogue
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }
// TODO: this function assume store out vector size is the same as OAccTile last dimension size // TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ? // how do we fix this ?
template <typename ODramWindowTmp, template <typename ODramWindowTmp,
typename OAccTile, typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set> memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile) CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
{ {
// TODO: this is ugly // TODO: this is ugly
...@@ -71,4 +92,76 @@ struct Default2DEpilogue ...@@ -71,4 +92,76 @@ struct Default2DEpilogue
} }
} }
}; };
template <typename Problem_, typename Policy_ = void>
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
using WG = WarpGemmMfmaDispatcher<ODataType,
ODataType,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(isCTransposed)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread has just a single item in Ndim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(isCTransposed)
{
// In this case each thread has just a single item in Mdim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
};
} // namespace ck_tile } // namespace ck_tile
...@@ -159,12 +159,8 @@ struct GemmKernel ...@@ -159,12 +159,8 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{ {
constexpr bool is_output_c_reg_transposed = if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC(); is_any_of<CDataType, fp16_t, bf16_t>::value)
if constexpr(!((GemmPipeline::VectorSizeC % 2 == 0 &&
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
is_output_c_reg_transposed) ||
!(std::is_same_v<CDataType, fp16_t> || std::is_same_v<CDataType, bf16_t>)))
{ {
if(kargs.KBatch != 1) if(kargs.KBatch != 1)
{ {
...@@ -182,7 +178,7 @@ struct GemmKernel ...@@ -182,7 +178,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.K % GemmPipeline::VectorSizeA != 0) if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
{ {
std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl; std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
return false; return false;
...@@ -197,7 +193,7 @@ struct GemmKernel ...@@ -197,7 +193,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.M % GemmPipeline::VectorSizeA != 0) if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
{ {
std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl; std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
return false; return false;
...@@ -213,7 +209,7 @@ struct GemmKernel ...@@ -213,7 +209,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.N % GemmPipeline::VectorSizeB != 0) if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
{ {
std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl; std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
return false; return false;
...@@ -228,7 +224,7 @@ struct GemmKernel ...@@ -228,7 +224,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.K % GemmPipeline::VectorSizeB != 0) if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
{ {
std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl; std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
return false; return false;
...@@ -244,7 +240,7 @@ struct GemmKernel ...@@ -244,7 +240,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.N % GemmPipeline::VectorSizeC != 0) if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
{ {
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false; return false;
...@@ -259,7 +255,7 @@ struct GemmKernel ...@@ -259,7 +255,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.M % GemmPipeline::VectorSizeC != 0) if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
{ {
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl; std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false; return false;
...@@ -275,14 +271,6 @@ struct GemmKernel ...@@ -275,14 +271,6 @@ struct GemmKernel
const GemmKernelArgs& kargs, const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset) const SplitKBatchOffset& splitk_batch_offset)
{ {
// const auto idxs = TilePartitioner{}();
// const auto i_m = idxs.at(number<0>{});
// const auto i_n = idxs.at(number<1>{});
// // options
// const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
// const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// // Convert pointers to tensor views
// auto a_tensor_view = [&]() {
const auto& a_tensor_view = [&]() { const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
...@@ -290,7 +278,7 @@ struct GemmKernel ...@@ -290,7 +278,7 @@ struct GemmKernel
a_ptr, a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k), make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1), make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{}, number<GemmPipeline::GetVectorSizeA()>{},
number<1>{}); number<1>{});
} }
else else
...@@ -299,7 +287,7 @@ struct GemmKernel ...@@ -299,7 +287,7 @@ struct GemmKernel
a_ptr, a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M), make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1), make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{}, number<GemmPipeline::GetVectorSizeA()>{},
number<1>{}); number<1>{});
} }
}(); }();
...@@ -311,7 +299,7 @@ struct GemmKernel ...@@ -311,7 +299,7 @@ struct GemmKernel
b_ptr, b_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.N), make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(kargs.stride_B, 1), make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{}, number<GemmPipeline::GetVectorSizeB()>{},
number<1>{}); number<1>{});
} }
else else
...@@ -320,7 +308,7 @@ struct GemmKernel ...@@ -320,7 +308,7 @@ struct GemmKernel
b_ptr, b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k), make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1), make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{}, number<GemmPipeline::GetVectorSizeB()>{},
number<1>{}); number<1>{});
} }
}(); }();
...@@ -333,7 +321,7 @@ struct GemmKernel ...@@ -333,7 +321,7 @@ struct GemmKernel
c_ptr, c_ptr,
make_tuple(kargs.M, kargs.N), make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1), make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{}, number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{}); number<1>{});
} }
else else
...@@ -501,16 +489,13 @@ struct GemmKernel ...@@ -501,16 +489,13 @@ struct GemmKernel
// Run Epilogue Pipeline // Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2); auto& c_block_window = gemm_tile_windows.at(I2);
constexpr bool is_output_c_reg_transposed = if constexpr(DstInMemOp == memory_operation_enum::set ||
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC(); !(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) || is_any_of<CDataType, fp16_t, bf16_t>::value))
(GemmPipeline::VectorSizeC % 2 == 0 &&
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
is_output_c_reg_transposed))
{ {
EpiloguePipeline{} EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>( .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
c_block_window, c_block_tile); c_block_window, c_block_tile, smem_ptr);
} }
} }
......
...@@ -21,6 +21,8 @@ struct GemmPipelineAgBgCrImplBase ...@@ -21,6 +21,8 @@ struct GemmPipelineAgBgCrImplBase
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
template <typename DstBlockTile, typename SrcTileWindow, typename DramTileWindowStep> template <typename DstBlockTile, typename SrcTileWindow, typename DramTileWindowStep>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window, SrcTileWindow& dram_tile_window,
......
...@@ -20,6 +20,8 @@ struct BaseGemmPipelineAgBgCrCompV3 ...@@ -20,6 +20,8 @@ struct BaseGemmPipelineAgBgCrCompV3
static constexpr index_t PrefillStages = 1; static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1; static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{ {
return num_loop > PrefetchStages; return num_loop > PrefetchStages;
...@@ -62,9 +64,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -62,9 +64,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA<Problem>(); static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>(); static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC<Problem>(); static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
...@@ -81,11 +83,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -81,11 +83,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler> template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase struct PipelineImpl : public PipelineImplBase
{ {
...@@ -110,9 +107,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -110,9 +107,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr index_t B_LDS_Read_Width = KPerXDL; constexpr index_t B_LDS_Read_Width = KPerXDL;
constexpr index_t A_Buffer_Load_Inst_Num = constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * VectorSizeA); MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num = constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * VectorSizeB); NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL); constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL); constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -20,6 +20,8 @@ struct BaseGemmPipelineAgBgCrMem ...@@ -20,6 +20,8 @@ struct BaseGemmPipelineAgBgCrMem
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
...@@ -113,9 +115,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -113,9 +115,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA<Problem>(); static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>(); static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC<Problem>(); static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
...@@ -133,11 +135,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -133,11 +135,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler> template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase struct PipelineImpl : public PipelineImplBase
{ {
......
...@@ -31,21 +31,21 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -31,21 +31,21 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK; static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Problem::VectorSizeA; static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK; static constexpr bool kPadK = Problem::kPadK;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
template <typename ADramBlockWindowTmp, template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp, typename BDramBlockWindowTmp,
typename AElementFunction, typename AElementFunction,
......
...@@ -16,8 +16,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -16,8 +16,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static constexpr auto I1 = number<1>{}; static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{}; static constexpr auto I2 = number<2>{};
static constexpr bool TransposeC = true;
// 3d + padding // 3d + padding
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
...@@ -383,8 +381,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -383,8 +381,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
} }
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{ {
...@@ -397,7 +393,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -397,7 +393,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
WarpTile::at(I0), WarpTile::at(I0),
WarpTile::at(I1), WarpTile::at(I1),
WarpTile::at(I2), WarpTile::at(I2),
TransposeC>; Problem::TransposeC>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType, using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType, typename Problem::BDataType,
typename Problem::CDataType, typename Problem::CDataType,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -25,6 +25,8 @@ struct GemmPipelineAGmemBGmemCRegV2 ...@@ -25,6 +25,8 @@ struct GemmPipelineAGmemBGmemCRegV2
static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK; static constexpr index_t kKPerBlock = BlockGemmShape::kK;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{ {
return integer_divide_ceil( return integer_divide_ceil(
...@@ -36,8 +38,6 @@ struct GemmPipelineAGmemBGmemCRegV2 ...@@ -36,8 +38,6 @@ struct GemmPipelineAGmemBGmemCRegV2
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(); Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
template <typename ADramBlockWindowTmp, template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp, typename BDramBlockWindowTmp,
typename AElementFunction, typename AElementFunction,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -27,6 +27,8 @@ struct GemmPipelineProblemBase ...@@ -27,6 +27,8 @@ struct GemmPipelineProblemBase
using BLayout = remove_cvref_t<typename Traits::BLayout>; using BLayout = remove_cvref_t<typename Traits::BLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>; using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr bool TransposeC = Traits::TransposeC;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = Traits::kPadM; static constexpr bool kPadM = Traits::kPadM;
...@@ -111,7 +113,6 @@ struct GemmPipelineProblemBase ...@@ -111,7 +113,6 @@ struct GemmPipelineProblemBase
return kPadK ? 1 : GetAlignmentB(); return kPadK ? 1 : GetAlignmentB();
} }
}(); }();
static constexpr index_t VectorSizeC = []() { static constexpr index_t VectorSizeC = []() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
......
...@@ -549,12 +549,6 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -549,12 +549,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Problem::TransposeC;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{ {
......
...@@ -32,9 +32,6 @@ class TestCkTileBatchedGemm : public ::testing::Test ...@@ -32,9 +32,6 @@ class TestCkTileBatchedGemm : public ::testing::Test
constexpr bool kPadM = false; constexpr bool kPadM = false;
constexpr bool kPadN = false; constexpr bool kPadN = false;
constexpr bool kPadK = false; constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
...@@ -51,11 +48,6 @@ class TestCkTileBatchedGemm : public ::testing::Test ...@@ -51,11 +48,6 @@ class TestCkTileBatchedGemm : public ::testing::Test
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 8;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape = using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>, ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>, ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
...@@ -63,21 +55,6 @@ class TestCkTileBatchedGemm : public ::testing::Test ...@@ -63,21 +55,6 @@ class TestCkTileBatchedGemm : public ::testing::Test
using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>; using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
...@@ -88,6 +65,20 @@ class TestCkTileBatchedGemm : public ::testing::Test ...@@ -88,6 +65,20 @@ class TestCkTileBatchedGemm : public ::testing::Test
CodegenGemmTraits>; CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
CodegenGemmPipeline::BlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
using Kernel = using Kernel =
ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>; ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <sstream> #include <sstream>
...@@ -65,9 +65,6 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -65,9 +65,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>; using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile:: using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>; TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
...@@ -106,6 +103,20 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -106,6 +103,20 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem, ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>>; ck_tile::UniversalGemmPipelineAgBgCrPolicy>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
GemmPipeline::BlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args); auto kargs = Kernel::MakeKernelArgs(args);
...@@ -244,7 +255,7 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -244,7 +255,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
public: public:
std::vector<int> k_batches_; std::vector<int> k_batches_;
void SetUp() override { k_batches_ = {1}; } void SetUp() override { k_batches_ = {1, 2}; }
template <bool PadM = true, bool PadN = true, bool PadK = true> template <bool PadM = true, bool PadN = true, bool PadK = true>
void Run(const int M, void Run(const int M,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <sstream> #include <sstream>
...@@ -29,9 +29,6 @@ class TestCkTileGroupedGemm : public ::testing::Test ...@@ -29,9 +29,6 @@ class TestCkTileGroupedGemm : public ::testing::Test
static const bool kPadM = false; static const bool kPadM = false;
static const bool kPadN = false; static const bool kPadN = false;
static const bool kPadK = false; static const bool kPadK = false;
static const bool kTilePermute = false;
static const ck_tile::index_t kOutputRank = 2;
static const int kBlockPerCu = 1; static const int kBlockPerCu = 1;
static const ck_tile::index_t M_Tile = 128; static const ck_tile::index_t M_Tile = 128;
...@@ -60,26 +57,6 @@ class TestCkTileGroupedGemm : public ::testing::Test ...@@ -60,26 +57,6 @@ class TestCkTileGroupedGemm : public ::testing::Test
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>; using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
template <typename CLayout>
using GemmEpilogue =
std::conditional_t<std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kTilePermute,
GroupedGemKernelParam::kOutputRank,
1,
0,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType,
CDataType,
GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN>>>;
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM, using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN, GroupedGemKernelParam::kPadN,
...@@ -100,10 +77,25 @@ class TestCkTileGroupedGemm : public ::testing::Test ...@@ -100,10 +77,25 @@ class TestCkTileGroupedGemm : public ::testing::Test
using CodegenGemmPipeline = using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>; ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
template <typename ALayout, typename BLayout, typename CLayout>
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
AccDataType,
CDataType,
CLayout,
CodegenGemmPipeline<ALayout, BLayout, CLayout>::BlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
CodegenPipelineProblem<ALayout, BLayout, CLayout>::TransposeC>>;
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
CodegenGemmPipeline<ALayout, BLayout, CLayout>, CodegenGemmPipeline<ALayout, BLayout, CLayout>,
GemmEpilogue<CLayout>>; GemmEpilogue<ALayout, BLayout, CLayout>>;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs) std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment