Unverified commit 3c93d3c4 authored by Mateusz Ozga, committed by GitHub

CK-Tile Grouped GEMM refactor and post PR fixes (#1756)

* Grouped gemm simple code refactor

* Offset invoker

* Invoke generic Run, and replace name of partitioner variable

* Fix type in tests

* Removed namespaces

* Add template param to avoid implicit cast

* Remove generic function

* Constant value

* Underlying enum type to int16_t

* Generalize partitioner function

* Remove whitespaces

* Rename function

* Using support

* Clang-format

* Clang-format

* Fn-partitioner description fn

* Typo

* Typo 2

* Better description

* Better description

* Refactor after review

* Use ctor instead of setter fn

* Invoke ctor and typo

* Comments

* Remove unnecessary comment

* Review, remove modulo
parent e7dce4d2
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #include <hip/hip_runtime.h>
@@ -49,7 +49,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
                                ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-    using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
+    using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
     using GemmEpilogue = std::conditional_t<
         CShuffleEpilogue,
@@ -61,8 +61,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
             kOutputRank,
             1,
             0,
-            TilePartitioner::kM,
-            TilePartitioner::kN>>,
+            TilePartitioner::MPerBlock,
+            TilePartitioner::NPerBlock>>,
         ck_tile::Default2DEpilogue<
             ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
...
@@ -56,7 +56,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
     ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                            ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                            ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-    using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
+    using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
     using GemmEpilogue = ck_tile::Default2DEpilogue<
         ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #include <hip/hip_runtime.h>
@@ -51,7 +51,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
                                ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-    using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
+    using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
     using GemmEpilogue = std::conditional_t<
         CShuffleEpilogue,
@@ -63,8 +63,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
             kOutputRank,
             1,
             0,
-            TilePartitioner::kM,
-            TilePartitioner::kN>>,
+            TilePartitioner::MPerBlock,
+            TilePartitioner::NPerBlock>>,
         ck_tile::Default2DEpilogue<
             ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
...
@@ -15,7 +15,6 @@
 #include "ck_tile/ops/gemm.hpp"
 #include "ck_tile/host.hpp"
 #include "grouped_gemm.hpp"
-#include "utils.hpp"
 namespace {
@@ -102,7 +101,7 @@ using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
                                           GemmEpilogue<CLayout>>;
 }; // namespace
-std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
+std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
 {
     return ::Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
 }
...
@@ -52,8 +52,8 @@ auto create_args(int argc, char* argv[])
     return std::make_tuple(result, arg_parser);
 }
-std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs);
-float grouped_gemm_calc(const std::vector<grouped_gemm_kargs>& gemm_descs,
+std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs);
+float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
                    const ck_tile::stream_config& s,
                    void* p_workspace_);
@@ -31,7 +31,7 @@ float invoke_gemm(int n_warmup,
 {
     ck_tile::DeviceMem gemm_workspace;
-    gemm_workspace.Realloc(GetWorkspaceSize(args));
+    gemm_workspace.Realloc(get_workspace_size(args));
     float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
         args,
@@ -128,16 +128,16 @@ int run_grouped_gemm_example_with_layouts(int argc,
         const ck_tile::index_t N = Ns[i];
         const ck_tile::index_t K = Ks[i];
-        stride_As[i] = f_get_default_stride(M, N, stride_As[i], a_layout);
-        stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], b_layout);
-        stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
+        stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], a_layout);
+        stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], b_layout);
+        stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], CLayout{});
-        a_m_k_tensors.push_back(
-            ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)));
-        b_k_n_tensors.push_back(
-            ck_tile::HostTensor<BDataType>(f_host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
-        c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
-            f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
+        a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
+            ck_tile::host_tensor_descriptor(M, K, stride_As[i], a_layout)));
+        b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
+            ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
+        c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
+            ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
         std::cout << "gemm[" << i << "]"
                   << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
@@ -178,7 +178,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
     for(int i = 0; i < group_count; ++i)
     {
         ck_tile::HostTensor<CDataType> c_m_n_host_ref(
-            f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
+            ck_tile::host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
         c_m_n_host_ref.SetZero();
         ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
             a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
...
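For orientation, here is a minimal host-side sketch (not part of the diff) of how the renamed entry points compose after this change. The layout tags, `grouped_gemm_kargs`, and the `stream_config` values are assumptions taken from this example's surrounding code.

```cpp
#include <vector>
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"

// Assumed aliases: ALayout/BLayout/CLayout are the example's layout tags and
// grouped_gemm_kargs is the host-args alias declared in grouped_gemm.hpp.
template <typename ALayout, typename BLayout, typename CLayout>
float run_groups(const std::vector<grouped_gemm_kargs>& args)
{
    ck_tile::DeviceMem gemm_workspace;
    // One GemmTransKernelArg per group is staged in this workspace.
    gemm_workspace.Realloc(get_workspace_size(args));
    // stream_config{nullptr, true}: default stream, time the kernel.
    return grouped_gemm<ALayout, BLayout, CLayout>(
        args, ck_tile::stream_config{nullptr, true}, gemm_workspace.GetDeviceBuffer());
}
```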
deleted file (the example's utils.hpp; these helpers moved into ck_tile, see host_tensor.hpp below):
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-#pragma once
-
-template <typename TLayout>
-constexpr auto
-f_host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
-{
-    using namespace ck_tile::literals;
-    if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
-    {
-        return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
-    }
-    else
-    {
-        return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
-    }
-}
-
-template <typename TLayout>
-constexpr auto
-f_get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
-{
-    if(stride == 0)
-    {
-        if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            return col;
-        }
-        else
-        {
-            return row;
-        }
-    }
-    else
-        return stride;
-}
@@ -54,7 +54,6 @@
 #include "ck_tile/core/tensor/tile_window_linear.hpp"
 #include "ck_tile/core/tensor/tile_window_utils.hpp"
 #include "ck_tile/core/tensor/update_tile.hpp"
-#include "ck_tile/core/utility/amd_address_space.hpp"
 #include "ck_tile/core/utility/bit_cast.hpp"
 #include "ck_tile/core/utility/functional.hpp"
 #include "ck_tile/core/utility/functional_with_tuple.hpp"
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -12,18 +12,37 @@
 namespace ck_tile {
+template <typename, bool>
+struct safe_underlying_type;
+
+template <typename T>
+struct safe_underlying_type<T, true>
+{
+    using type = std::underlying_type_t<T>;
+};
+
+template <typename T>
+struct safe_underlying_type<T, false>
+{
+    using type = void;
+};
+
+template <typename T>
+using safe_underlying_type_t = typename safe_underlying_type<T, std::is_enum<T>::value>::type;
+
-enum struct address_space_enum
+enum struct address_space_enum : std::uint16_t
 {
-    generic,
+    generic = 0,
     global,
     lds,
     sgpr,
-    vgpr,
+    constant,
+    vgpr
 };
-enum struct memory_operation_enum
+enum struct memory_operation_enum : std::uint16_t
 {
-    set,
+    set = 0,
     atomic_add,
     atomic_max,
     add
@@ -109,4 +128,30 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0)
 #endif
 }
+#define CK_CONSTANT_ADDRESS_SPACE \
+    __attribute__((address_space( \
+        static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))
+
+template <typename T>
+__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
+{
+    // cast a pointer in "Constant" address space (4) to "Generic" address space (0)
+    // only c-style pointer cast seems be able to be compiled
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wold-style-cast"
+    return (T*)(p); // NOLINT(old-style-cast)
+#pragma clang diagnostic pop
+}
+
+template <typename T>
+__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
+{
+    // cast a pointer in "Generic" address space (0) to "Constant" address space (4)
+    // only c-style pointer cast seems be able to be compiled
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wold-style-cast"
+    return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
+#pragma clang diagnostic pop
+}
 } // namespace ck_tile
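A small standalone sketch of why `safe_underlying_type_t` exists: `std::underlying_type_t` on a non-enum is ill-formed, while this wrapper degrades to `void`, so the macro above can be written without a hard error path. The declarations are copied locally here so the assertions compile on their own; with `constant` inserted before `vgpr`, `address_space_enum::constant` lands on 4, matching the AMDGPU constant address space.

```cpp
#include <cstdint>
#include <type_traits>

enum struct address_space_enum : std::uint16_t { generic = 0, global, lds, sgpr, constant, vgpr };

template <typename, bool> struct safe_underlying_type;
template <typename T> struct safe_underlying_type<T, true>  { using type = std::underlying_type_t<T>; };
template <typename T> struct safe_underlying_type<T, false> { using type = void; };
template <typename T>
using safe_underlying_type_t = typename safe_underlying_type<T, std::is_enum<T>::value>::type;

// Enums resolve to their underlying type...
static_assert(std::is_same_v<safe_underlying_type_t<address_space_enum>, std::uint16_t>);
// ...while non-enums resolve to void instead of failing to compile.
static_assert(std::is_same_v<safe_underlying_type_t<int>, void>);
// The macro's argument evaluates to 4, the AMDGPU "constant" address space.
static_assert(static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant) == 4);
```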
deleted file (ck_tile/core/utility/amd_address_space.hpp; its contents were merged into the arch header above):
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-#pragma once
-
-#include "ck_tile/core/config.hpp"
-
-// Address Space for AMDGCN
-// https://llvm.org/docs/AMDGPUUsage.html#address-space
-
-namespace ck_tile {
-
-#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
-
-template <typename T>
-__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
-{
-    // cast a pointer in "Constant" address space (4) to "Generic" address space (0)
-    // only c-style pointer cast seems be able to be compiled
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wold-style-cast"
-    return (T*)p; // NOLINT(old-style-cast)
-#pragma clang diagnostic pop
-}
-
-template <typename T>
-__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
-{
-    // cast a pointer in "Generic" address space (0) to "Constant" address space (4)
-    // only c-style pointer cast seems be able to be compiled
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wold-style-cast"
-    return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
-#pragma clang diagnostic pop
-}
-} // namespace ck_tile
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -678,4 +678,37 @@ struct HostTensor
     Descriptor mDesc;
     Data mData;
 };
+
+template <typename TLayout>
+auto host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
+{
+    using namespace ck_tile::literals;
+    if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
+    {
+        return HostTensorDescriptor({row, col}, {stride, 1_uz});
+    }
+    else
+    {
+        return HostTensorDescriptor({row, col}, {1_uz, stride});
+    }
+}
+
+template <typename TLayout>
+auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
+{
+    if(stride == 0)
+    {
+        if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
+        {
+            return col;
+        }
+        else
+        {
+            return row;
+        }
+    }
+    else
+        return stride;
+}
 } // namespace ck_tile
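A short usage sketch of the promoted helpers (assuming `ck_tile::host.hpp` pulls them in, as the example above does): a zero stride requests the dense default, which is the row length for RowMajor and the column length for ColumnMajor.

```cpp
#include <cstddef>
#include "ck_tile/host.hpp"

void make_host_a()
{
    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    const std::size_t M = 128, K = 64;
    // stride == 0 requests the dense default: K (the row length) for RowMajor.
    const std::size_t stride_a = ck_tile::get_default_stride(M, K, 0, Row{}); // == 64
    // lengths {M, K}, strides {stride_a, 1} for a row-major matrix.
    auto desc = ck_tile::host_tensor_descriptor(M, K, stride_a, Row{});
    ck_tile::HostTensor<float> a(desc); // host-side buffer for an M x K matrix
}
```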
@@ -101,9 +101,12 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
     CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
     {
-        const auto [i_m, i_n] = TilePartitioner{}();
+        const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
+        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
+        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
         const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
         const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
         const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
...
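The two `blockIdx.z` lines above recover both the batch index and the split-K slice from a single grid dimension without a modulo. A standalone sketch of that arithmetic, assuming z is laid out as `i_batch * KBatch + i_k`:

```cpp
#include <cassert>

// z enumerates batch x k-split; recover both indices with one division.
void decompose(int z, int KBatch)
{
    const int i_batch = z / KBatch;           // which GEMM of the batch
    const int i_k = z - i_batch * KBatch;     // which K-split slice; same as z % KBatch
    assert(i_k >= 0 && i_k < KBatch);
}
```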
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -174,7 +174,7 @@ struct GemmKernel
         if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
         {
-            if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
+            if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
             {
                 return false;
             }
@@ -185,7 +185,7 @@ struct GemmKernel
         }
         else
         {
-            if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
+            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
             {
                 return false;
             }
@@ -197,7 +197,7 @@ struct GemmKernel
         if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
         {
-            if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
+            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
             {
                 return false;
             }
@@ -208,7 +208,7 @@ struct GemmKernel
         }
         else
         {
-            if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
+            if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
             {
                 return false;
             }
@@ -220,7 +220,7 @@ struct GemmKernel
         if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
         {
-            if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
+            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
             {
                 return false;
             }
@@ -231,7 +231,7 @@ struct GemmKernel
         }
         else
         {
-            if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
+            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                 return false;
             }
@@ -323,17 +323,17 @@ struct GemmKernel
         const auto& a_tensor_view = views.at(I0);
         if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
         {
-            return pad_tensor_view(
-                a_tensor_view,
-                make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-                sequence<false, GemmPipeline::kPadK>{});
+            return pad_tensor_view(a_tensor_view,
+                                   make_tuple(number<TilePartitioner::MPerBlock>{},
+                                              number<TilePartitioner::KPerBlock>{}),
+                                   sequence<false, GemmPipeline::kPadK>{});
         }
         else
         {
-            return pad_tensor_view(
-                a_tensor_view,
-                make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-                sequence<GemmPipeline::kPadM, false>{});
+            return pad_tensor_view(a_tensor_view,
+                                   make_tuple(number<TilePartitioner::MPerBlock>{},
+                                              number<TilePartitioner::KPerBlock>{}),
+                                   sequence<GemmPipeline::kPadM, false>{});
         }
     }();
@@ -341,17 +341,17 @@ struct GemmKernel
         const auto& b_tensor_view = views.at(I1);
         if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
         {
-            return pad_tensor_view(
-                b_tensor_view,
-                make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-                sequence<false, GemmPipeline::kPadK>{});
+            return pad_tensor_view(b_tensor_view,
+                                   make_tuple(number<TilePartitioner::NPerBlock>{},
+                                              number<TilePartitioner::KPerBlock>{}),
+                                   sequence<false, GemmPipeline::kPadK>{});
         }
         else
         {
-            return pad_tensor_view(
-                b_tensor_view,
-                make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-                sequence<GemmPipeline::kPadN, false>{});
+            return pad_tensor_view(b_tensor_view,
+                                   make_tuple(number<TilePartitioner::NPerBlock>{},
+                                              number<TilePartitioner::KPerBlock>{}),
+                                   sequence<GemmPipeline::kPadN, false>{});
         }
     }();
@@ -359,17 +359,17 @@ struct GemmKernel
         const auto& c_tensor_view = views.at(I2);
         if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
         {
-            return pad_tensor_view(
-                c_tensor_view,
-                make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
-                sequence<false, GemmPipeline::kPadN>{});
+            return pad_tensor_view(c_tensor_view,
+                                   make_tuple(number<TilePartitioner::MPerBlock>{},
+                                              number<TilePartitioner::NPerBlock>{}),
+                                   sequence<false, GemmPipeline::kPadN>{});
         }
         else
         {
-            return pad_tensor_view(
-                c_tensor_view,
-                make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
-                sequence<GemmPipeline::kPadM, false>{});
+            return pad_tensor_view(c_tensor_view,
+                                   make_tuple(number<TilePartitioner::MPerBlock>{},
+                                              number<TilePartitioner::NPerBlock>{}),
+                                   sequence<GemmPipeline::kPadM, false>{});
         }
     }();
@@ -383,19 +383,19 @@ struct GemmKernel
         const auto& a_pad_view = views.at(I0);
         const auto& a_block_window = make_tile_window(
             a_pad_view,
-            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
             {i_m, 0});
         const auto& b_pad_view = views.at(I1);
         const auto& b_block_window = make_tile_window(
             b_pad_view,
-            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
+            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
             {i_n, 0});
         const auto& c_pad_view = views.at(I2);
         auto c_block_window = make_tile_window(
             c_pad_view,
-            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
             {i_m, i_n});
         return make_tuple(a_block_window, b_block_window, c_block_window);
@@ -426,7 +426,7 @@ struct GemmKernel
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple =
             MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
-        ;
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
@@ -456,7 +456,10 @@ struct GemmKernel
     CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
     {
-        const auto [i_m, i_n] = TilePartitioner{}();
+        const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
+        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
+        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
         const SplitKBatchOffset splitk_batch_offset(kargs);
         // options
         const ADataType* a_ptr =
...
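The renamed constants make the support checks above read naturally: a problem is supported only when each dimension either divides evenly into its per-block tile or the pipeline pads that dimension. A hypothetical host-side helper (not from the diff; the real check also branches on layouts and vector sizes) that mirrors the logic:

```cpp
// Mirrors GemmKernel's divisibility checks: a non-padded dimension must be an
// exact multiple of its per-block tile size.
bool supports(int M, int N, int K,
              int MPerBlock, int NPerBlock, int KPerBlock,
              bool kPadM, bool kPadN, bool kPadK)
{
    if(M % MPerBlock != 0 && !kPadM) return false;
    if(N % NPerBlock != 0 && !kPadN) return false;
    if(K % KPerBlock != 0 && !kPadK) return false;
    return true;
}
```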
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
 #include "ck_tile/core.hpp"
 namespace ck_tile {
-template <typename BlockGemmShape_>
-struct GemmTilePartitioner
+/** @brief Struct representing 2D block index mapping into 3D output tile space. */
+template <typename BlockGemmShapeType>
+struct GemmTile2DPartitioner
 {
-    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
+    using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
-    static constexpr index_t kM = BlockGemmShape::kM;
-    static constexpr index_t kN = BlockGemmShape::kN;
-    static constexpr index_t kK = BlockGemmShape::kK;
+    static constexpr index_t MPerBlock = BlockGemmShape::kM;
+    static constexpr index_t NPerBlock = BlockGemmShape::kN;
+    static constexpr index_t KPerBlock = BlockGemmShape::kK;
-    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size)
+    /** @brief Returns 3D grid size. */
+    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) noexcept(
+        noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
     {
-        index_t GridDimX = (M + kM - 1) / kM;
-        index_t GridDimY = (N + kN - 1) / kN;
-        index_t GridDimZ = batch_size;
+        const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
+        const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
+        const index_t GridDimZ = batch_size;
         return dim3(GridDimX, GridDimY, GridDimZ);
     }
-    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
+    /**
+     * @brief Returns the number of loops.
+     * @param [in] K is dimension
+     */
+    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
     {
-        return integer_divide_ceil(K, kK);
+        return integer_divide_ceil(K, KPerBlock);
     }
-    CK_TILE_DEVICE auto operator()()
+    /**
+     * @brief The function returns 2D output tile space.
+     * @param [in] blockIdx is blockIdx.x
+     * @param [in] blockIdy is blockIdx.y
+     * @return Returns the output tile indexes.
+     */
+    CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx,
+                                                            index_t blockIdy) noexcept
+        -> const tuple<index_t, index_t>
     {
-        const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM);
-        const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN);
+        const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
+        const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy);
         return make_tuple(iM, iN);
     }
 };
-template <typename BlockGemmShape_>
+/**
+ * @brief Struct representing 1D block index mapping into 2D output tile space.
+ */
+template <typename BlockGemmShapeType>
 struct GemmTile1DPartitioner
 {
-    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
+    using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
     static constexpr index_t MPerBlock = BlockGemmShape::kM;
     static constexpr index_t NPerBlock = BlockGemmShape::kN;
     static constexpr index_t KPerBlock = BlockGemmShape::kK;
+    /** @brief delete default ctor with no any object */
+    constexpr GemmTile1DPartitioner() noexcept = delete;
+    /** @brief constructs an object that does contain a N value. */
+    constexpr GemmTile1DPartitioner(index_t N) noexcept { N_ = N; }
-    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N)
+    /** @brief Returns 1D grid size. */
+    CK_TILE_HOST static constexpr auto
+    GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
     {
-        index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
-        index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
+        const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
+        const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
         return dim3(GridDimX * GridDimY, 1, 1);
     }
-    CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N)
+    /**
+     * @brief Returns the number of blocks in N.
+     * @param [in] N is dimension
+     */
+    CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) noexcept -> index_t
     {
         return integer_divide_ceil(N, NPerBlock);
     }
-    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
+    /**
+     * @brief Returns the number of loops.
+     * @param [in] K is dimension
+     */
+    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
     {
         return integer_divide_ceil(K, KPerBlock);
     }
-    CK_TILE_DEVICE auto operator()(index_t blockOffset, index_t NBlockSize)
-    {
-        index_t iM = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) /
-                                                    GetNBlock(NBlockSize) * MPerBlock);
-        index_t iN = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) %
-                                                    GetNBlock(NBlockSize) * NPerBlock);
-        return make_tuple(iM, iN);
-    }
+    /**
+     * @brief The function returns 2D output tile space.
+     * @param [in] blockIdx is blockIdx.x - block_start.
+     */
+    CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx) noexcept
+        -> const tuple<index_t, index_t>
+    {
+        const index_t NBlock = GetNBlock(N_);
+        const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlock);
+        const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - (iM)*NBlock);
+        return make_tuple(iM, iN);
+    }
+
+    private:
+    CK_TILE_DEVICE static index_t N_;
 };
+/**
+ * @brief `GemmTile1DPartitioner::GetOutputTileIndex`'s std::false specialization,
+ * checking expression validity in-place for ill-formed.
+ */
+template <typename, typename = void>
+struct HasFnOneArgImpl : std::false_type
+{
+};
+/**
+ * @brief `GemmTile1DPartitioner::GetOutputTileIndex`'s std::true specialization,
+ * checking expression validity in-place for well-formed.
+ * @note: `1` - a constant value indicating the number of parameters in the function.
+ */
+template <typename T>
+struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIndex(1))>>
+    : std::true_type
+{
+};
+/**
+ * @brief Struct used to calculate offseted tile indexes.
+ * @note: The struct supports the 1D-Partitioner mechanism,
+ * enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
+ * otherwise std::false_type.
+ */
+template <typename PartitionerFn,
+          typename = typename std::enable_if_t<HasFnOneArgImpl<PartitionerFn>{}>>
+struct OffsettedTile1DPartitioner
+{
+    /**
+     * @brief The function subtracts the block's start (offset) from 1D raw-indexes.
+     * @param [in] block_start is `blockIdx.x - block_start`.
+     * @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index.
+     */
+    [[nodiscard]] CK_TILE_DEVICE static constexpr auto GetOffsetedTileIndex(index_t block_start,
+                                                                            index_t N) noexcept
+        -> const tuple<index_t, index_t>
+    {
+        const auto [iM, iN] = PartitionerFn(N).GetOutputTileIndex(blockIdx.x - block_start);
+        return make_tuple(iM, iN);
+    }
+};
...
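A host-side sketch of the 1D-to-2D mapping that `GemmTile1DPartitioner::GetOutputTileIndex` now performs: the row of the output tile grid comes from an integer division and the column from a subtraction rather than a modulo (the "Review, remove modulo" commit above), which is equivalent but avoids a second division on the GPU.

```cpp
#include <cassert>

// blockIdx1d is the block index already shifted by the group's block_start.
void map_block(int blockIdx1d, int N, int NPerBlock)
{
    const int NBlock = (N + NPerBlock - 1) / NPerBlock; // integer_divide_ceil(N, NPerBlock)
    const int iM = blockIdx1d / NBlock;                 // row of the output tile grid
    const int iN = blockIdx1d - iM * NBlock;            // column, without a modulo
    assert(iN == blockIdx1d % NBlock);                  // same result as the modulo form
}
```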
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
+#include <iostream>
+#include <string>
 #include "ck_tile/core/numeric/math.hpp"
 #include "ck_tile/core/utility/literals.hpp"
-#include "ck_tile/core/utility/amd_address_space.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
-#include "ck_tile/core.hpp"
-#include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
 #include "ck_tile/host.hpp"
 namespace ck_tile {
-struct GroupedGemmHostArgs
+struct GroupedGemmHostArgs : public ck_tile::GemmHostArgs
 {
-    const void* a_ptr;
-    const void* b_ptr;
-    void* c_ptr;
-    index_t M;
-    index_t N;
-    index_t K;
-    index_t stride_A;
-    index_t stride_B;
-    index_t stride_C;
+    CK_TILE_HOST GroupedGemmHostArgs() noexcept = default;
+    CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
+                                     const void* b_ptr_,
+                                     void* c_ptr_,
+                                     ck_tile::index_t M_,
+                                     ck_tile::index_t N_,
+                                     ck_tile::index_t K_,
+                                     ck_tile::index_t stride_A_,
+                                     ck_tile::index_t stride_B_,
+                                     ck_tile::index_t stride_C_)
+        : GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, KBatch, M_, N_, K_, stride_A_, stride_B_, stride_C_)
+    {
+    }
+
+    private:
+    static constexpr index_t KBatch = 1;
 };
 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
-struct GroupedGemmKernel
+struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
 {
     using TilePartitioner = remove_cvref_t<TilePartitioner_>;
     using GemmPipeline = remove_cvref_t<GemmPipeline_>;
     using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
     using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
     using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
     using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
-    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
     using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
     using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
     using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
+    using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
+    using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
+    using GemmKernelArgs = typename Base::GemmKernelArgs;
+
+    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
+    static constexpr index_t KBatch = 1;
     struct GemmTransKernelArg
     {
-        GroupedGemmHostArgs group_karg;
+        GemmKernelArgs group_karg;
         ck_tile::index_t block_start;
         ck_tile::index_t block_end;
         GemmTransKernelArg() = default;
-        GemmTransKernelArg(GroupedGemmHostArgs&& karg, index_t bl_start, index_t bl_end)
+        GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
             : group_karg{karg}, block_start{bl_start}, block_end{bl_end}
         {
         }
     };
-    __host__ static size_t GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
+    __host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
+        -> std::size_t
     {
         return gemm_descs.size() * sizeof(GemmTransKernelArg);
     }
-    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
+    __host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
-    using Hargs = GroupedGemmHostArgs;
-    __host__ static constexpr auto GridSize(const std::vector<Hargs>& gemm_descs)
+    __host__ static constexpr auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
     {
         index_t grid_size = 0;
         for(const auto& it_desc : gemm_descs)
@@ -77,7 +84,8 @@ struct GroupedGemmKernel
         return dim3(grid_size, 1, 1);
     }
-    CK_TILE_HOST static auto MakeKargs(const std::vector<Hargs>& gemm_descs)
+    CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
+        -> std::vector<GemmTransKernelArg>
     {
         std::vector<GemmTransKernelArg> gemm_kernel_args_;
         index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
@@ -100,22 +108,23 @@ struct GroupedGemmKernel
             const index_t stride_c = gemm_descs[i].stride_C;
             const auto dim3 = TilePartitioner::GridSize(M, N);
-            const index_t grid_size_grp = dim3.x * 1 * 1;
+            const index_t grid_size_grp = dim3.x;
             const index_t block_start = grid_size;
             const index_t block_end = grid_size + grid_size_grp;
             grid_size += grid_size_grp;
-            auto karg = GroupedGemmHostArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
-                                            type_convert<const BDataType*>(gemm_descs[i].b_ptr),
-                                            type_convert<CDataType*>(gemm_descs[i].c_ptr),
-                                            M,
-                                            N,
-                                            K,
-                                            stride_a,
-                                            stride_b,
-                                            stride_c};
+            auto karg = GemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
+                                       type_convert<const BDataType*>(gemm_descs[i].b_ptr),
+                                       type_convert<CDataType*>(gemm_descs[i].c_ptr),
+                                       M,
+                                       N,
+                                       K,
+                                       stride_a,
+                                       stride_b,
+                                       stride_c,
+                                       KBatch};
             gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
         }
@@ -123,162 +132,34 @@ struct GroupedGemmKernel
         return gemm_kernel_args_;
     }
-    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
     {
         return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }
-    CK_TILE_DEVICE void Run(const Hargs& kargs, const index_t block_start) const
-    {
-        const auto [i_m, i_n] = TilePartitioner{}(block_start, kargs.N);
-        // options
-        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
-        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
-        // Convert pointers to tensor views
-        auto a_tensor_view = [&]() {
-            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    a_start,
-                    make_tuple(kargs.M, kargs.K),
-                    make_tuple(kargs.stride_A, 1),
-                    number<GemmPipeline::VectorSizeA>{},
-                    number<1>{});
-            }
-            else
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    a_start,
-                    make_tuple(kargs.M, kargs.K),
-                    make_tuple(1, kargs.stride_A),
-                    number<1>{},
-                    number<1>{});
-            }
-        }();
-        auto b_tensor_view = [&]() {
-            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    b_start,
-                    make_tuple(kargs.N, kargs.K),
-                    make_tuple(1, kargs.stride_B),
-                    number<1>{},
-                    number<1>{});
-            }
-            else
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    b_start,
-                    make_tuple(kargs.N, kargs.K),
-                    make_tuple(kargs.stride_B, 1),
-                    number<GemmPipeline::VectorSizeB>{},
-                    number<1>{});
-            }
-        }();
-        auto a_pad_view = [&]() {
-            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
-            {
-                return pad_tensor_view(a_tensor_view,
-                                       make_tuple(number<TilePartitioner::MPerBlock>{},
-                                                  number<TilePartitioner::KPerBlock>{}),
-                                       sequence<false, GemmPipeline::kPadK>{});
-            }
-            else
-            {
-                return pad_tensor_view(a_tensor_view,
-                                       make_tuple(number<TilePartitioner::MPerBlock>{},
-                                                  number<TilePartitioner::KPerBlock>{}),
-                                       sequence<GemmPipeline::kPadM, false>{});
-            }
-        }();
-        // clang-format on
-        auto a_block_window = make_tile_window(
-            a_pad_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            {i_m, 0});
-        auto b_pad_view = [&]() {
-            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
-            {
-                return pad_tensor_view(b_tensor_view,
-                                       make_tuple(number<TilePartitioner::NPerBlock>{},
-                                                  number<TilePartitioner::KPerBlock>{}),
-                                       sequence<false, GemmPipeline::kPadK>{});
-            }
-            else
-            {
-                return pad_tensor_view(b_tensor_view,
-                                       make_tuple(number<TilePartitioner::NPerBlock>{},
-                                                  number<TilePartitioner::KPerBlock>{}),
-                                       sequence<GemmPipeline::kPadN, false>{});
-            }
-        }();
-        auto b_block_window = make_tile_window(
-            b_pad_view,
-            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            {i_n, 0});
-        // allocate LDS
-        __shared__ char smem_ptr[GetSmemSize()];
-        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
-        // Run GEMM cooperatively by whole wokrgroup.
-        auto c_block_tile =
-            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
-        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
-        auto c_tensor_view = [&]() {
-            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    c_start,
-                    make_tuple(kargs.M, kargs.N),
-                    make_tuple(kargs.stride_C, 1),
-                    number<GemmPipeline::VectorSizeC>{},
-                    number<1>{});
-            }
-            else
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    c_start,
-                    make_tuple(kargs.M, kargs.N),
-                    make_tuple(1, kargs.stride_C),
-                    number<1>{},
-                    number<1>{});
-            }
-        }();
-        auto c_pad_view = [&]() {
-            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
-            {
-                return pad_tensor_view(c_tensor_view,
-                                       make_tuple(number<TilePartitioner::MPerBlock>{},
-                                                  number<TilePartitioner::NPerBlock>{}),
-                                       sequence<false, GemmPipeline::kPadN>{});
-            }
-            else
-            {
-                return pad_tensor_view(c_tensor_view,
-                                       make_tuple(number<TilePartitioner::MPerBlock>{},
-                                                  number<TilePartitioner::NPerBlock>{}),
-                                       sequence<GemmPipeline::kPadM, false>{});
-            }
-        }();
-        auto CBlockWindow_pad = make_tile_window(
-            c_pad_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
-            {i_m, i_n});
-        EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
-    }
+    CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const
+    {
+        const auto [iM, iN] =
+            OffsetTile1DPartitioner::GetOffsetedTileIndex(kargs.block_start, kargs.group_karg.N);
+        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
+        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
+
+        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs.group_karg, blockIdx.z);
+
+        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.group_karg.a_ptr);
+        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.group_karg.b_ptr);
+        CDataType* c_ptr = static_cast<CDataType*>(kargs.group_karg.c_ptr);
+
+        // allocate LDS
+        __shared__ char smem_ptr[GetSmemSize()];
+
+        this->RunGemm(
+            a_ptr, b_ptr, c_ptr, smem_ptr, kargs.group_karg, splitk_batch_offset, i_m, i_n);
+    }
     CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                   int group_count) const
+                                   index_t group_count) const
     {
         const index_t block_id = ck_tile::get_block_1d_id();
         const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
@@ -286,7 +167,7 @@ struct GroupedGemmKernel
         index_t left = 0;
         index_t right = group_count;
-        index_t group_id = index_t((left + right) / 2);
+        index_t group_id = index_t((left + right) >> 1);
         while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
                  block_id < gemm_desc_ptr[group_id].block_end)) &&
@@ -300,10 +181,10 @@ struct GroupedGemmKernel
             {
                 left = group_id;
             }
-            group_id = index_t((left + right) / 2);
+            group_id = index_t((left + right) >> 1);
         }
-        Run(gemm_desc_ptr[group_id].group_karg, gemm_desc_ptr[group_id].block_start);
+        Run(gemm_desc_ptr[group_id]);
     }
 };
...
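For clarity, a standalone sketch of the group lookup in `operator()`: each group owns a half-open block range [block_start, block_end), the ranges are sorted and contiguous (they are accumulated in MakeKargs), so a binary search finds the group that owns the current block. The kernel's loop condition also carries an extra bound check elided in the diff above; this sketch assumes block_id always falls inside some range, as the launch grid guarantees.

```cpp
#include <vector>

struct Range { int block_start; int block_end; }; // half-open [start, end)

// Binary search over sorted, contiguous ranges; mirrors the kernel's loop,
// including the shift in place of a division by two.
int find_group(const std::vector<Range>& groups, int block_id)
{
    int left = 0, right = static_cast<int>(groups.size());
    int group_id = (left + right) >> 1;
    while(!(block_id >= groups[group_id].block_start &&
            block_id < groups[group_id].block_end))
    {
        if(block_id < groups[group_id].block_start)
            right = group_id; // target lies in an earlier range
        else
            left = group_id;  // target lies in a later range
        group_id = (left + right) >> 1;
    }
    return group_id;
}
```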
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
 #include <sstream>
@@ -61,7 +61,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
                                ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-    using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
+    using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
     using GemmEpilogue = std::conditional_t<
         CShuffleEpilogue,
@@ -73,8 +73,8 @@ class TestCkTileBatchedGemm : public ::testing::Test
             kOutputRank,
             1,
             0,
-            TilePartitioner::kM,
-            TilePartitioner::kN>>,
+            TilePartitioner::MPerBlock,
+            TilePartitioner::NPerBlock>>,
         ck_tile::Default2DEpilogue<
             ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
...
@@ -59,7 +59,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
     ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                            ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                            ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-    using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
+    using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
     using GemmEpilogue = ck_tile::Default2DEpilogue<
         ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
...