Commit c400e5b3 authored by Adam Osewski

Introduce static encoding pattern

parent 6fe9e964
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -7,6 +7,7 @@
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
namespace ck_tile {
/**
* @brief Enumeration describing tile distribution patterns.
*
*/
enum struct tile_distribution_pattern
{
/**
* @brief Thread raked pattern.
*
*/
thread_raked,
/**
* @brief Warp raked pattern.
*
*/
warp_raked,
/**
* @brief Block raked pattern - aka linear.
*
*/
block_raked,
// TODO pattern taking into account MFMA attributes:
// block_fmha_pipeline_qx_ks_vs_custom_policy.hpp::51 MakeQDramTileDistribution()
};
struct TileDistributionEcodingPattern
{
};
/**
* @brief Class creating 2D static tile distribution with different load/store patterns.
*
* @note We always assume that Tile is YPerTile x XPerTile where X dim (rightmost)
* is contiguous and we can do vector load on this dimension.
*
* @tparam BlockSize Number of threads in a workgroup.
* @tparam YPerTile The tile size of outer/leftmost dimension.
* @tparam XPerTile The tile size of inner/rightmost dimension (contiguous).
* @tparam VecSize The vector access size.
* @tparam DistributionPattern The enumeration describing used access pattern.
*/
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
tile_distribution_pattern DistributionPattern>
struct TileDistributionEncodingPattern2D : public TileDistributionEcodingPattern
{
};
// Thread raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::thread_raked>
: public TileDistributionEcodingPattern
{
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
// TODO: make a pattern where the condition below does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
constexpr index_t warp_size = get_warp_size();
constexpr index_t num_warps = BlockSize / get_warp_size();
constexpr index_t X1 = VecSize;
constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
// # of rows in Y dim accessed by single wavefront in one iteration
constexpr index_t Y1 = warp_size / X0;
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
constexpr index_t Y0 = num_warps;
// YPerWarp = YPerTile / Y0;
// Y2 = YPerWarp / Y1;
constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * Y1 * Y0 must cover whole workgroup!");
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
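// e.g. (hypothetical sizes, assuming wave64): BlockSize=256, YPerTile=128, XPerTile=64, VecSize=8
// gives X1=8, X0=8, Y1=8, Y0=4, Y2=4 - each thread does 8-wide vector loads and
// rakes Y2=4 consecutive rows on its own.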
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
};
// Warp raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::warp_raked>
: public TileDistributionEcodingPattern
{
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
constexpr index_t warp_size = get_warp_size();
constexpr index_t num_warps = BlockSize / get_warp_size();
constexpr index_t X1 = VecSize;
constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
constexpr index_t Y0 = num_warps;
static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y0 must cover whole workgroup!");
constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
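// e.g. (hypothetical sizes, assuming wave64): BlockSize=256, YPerTile=128, XPerTile=64, VecSize=8
// gives X0=8, Y2=8, Y0=4, Y1=4 - each warp rakes a contiguous block of Y1*Y2=32 rows,
// with each thread stepping through it by Y2 rows per iteration.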
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
};
// Block raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::block_raked>
: public TileDistributionEcodingPattern
{
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
// TODO: make a pattern where the condition below does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
constexpr index_t warp_size = get_warp_size();
constexpr index_t num_warps = BlockSize / get_warp_size();
constexpr index_t X1 = VecSize;
constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
constexpr index_t Y1 = num_warps;
static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
constexpr index_t Y0 = YPerTile / (Y2 * Y1);
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
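// e.g. (hypothetical sizes, assuming wave64): BlockSize=256, YPerTile=128, XPerTile=64, VecSize=8
// gives X0=8, Y2=8, Y1=4, Y0=4 - the whole workgroup covers a contiguous slab of
// Y1*Y2=32 rows per iteration and advances by that slab Y0=4 times (hence "linear").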
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
};
} // namespace ck_tile
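To illustrate how the new pattern is meant to be consumed, here is a minimal usage sketch (the block/tile/vector sizes are hypothetical and wave64 is assumed); the updated pipeline policy below calls the pattern the same way for its DRAM tile distributions:

```cpp
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"

// Hypothetical sizes for illustration: 256 threads, a 128 x 64 tile,
// 8-element vector accesses on the contiguous (rightmost) dimension.
using Pattern = ck_tile::TileDistributionEncodingPattern2D<
    256, // BlockSize
    128, // YPerTile (outer dimension)
    64,  // XPerTile (contiguous dimension)
    8,   // VecSize
    ck_tile::tile_distribution_pattern::thread_raked>;

// The resulting distribution is a compile-time object (see tile_distribution.hpp).
constexpr auto dist = Pattern::Make2DStaticTileDistribution();
```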
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -17,15 +17,23 @@ struct UniversalGemmPipelineAgBgCrPolicy
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
/**
* @brief Get the maximum global memory vector load size.
*
* @tparam Problem The UniversalGemmPipelineProblem object.
* @tparam DataType The tensor data type we're considering.
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
* @return Maximum DRAM vector load size.
*/
template <typename Problem, typename DataType, index_t MNPerBlock>
-CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize()
+CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
{
// TODO: this does not take into account the size of the contiguous dim!
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
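// e.g. (hypothetical sizes): fp16 (2 bytes), MNPerBlock=128, KPerBlock=64, BlockSize=256
// -> 32 elements per thread, so the 16-byte (8-element) load below is selected.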
// Assume sizeof(DataType) evenly divides 16!
if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0)
{
return (16 / sizeof(DataType));
@@ -56,7 +64,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
-return GetVectorLoadSize<Problem, ADataType, MPerBlock>();
+return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock>();
}
template <typename Problem>
@@ -65,7 +73,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
-return GetVectorLoadSize<Problem, BDataType, NPerBlock>();
+return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock>();
}
/**
@@ -90,9 +98,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
using CLayout = typename Problem::CLayout;
using CWarpDstr = typename WG::CWarpDstr;
// constexpr auto c_warp_x_lengths = CWarpDstr::get_lengths();
// using c_warp_hs_lengths = typename CWarpDstrEncoding::HsLengthss;
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
@@ -100,11 +105,11 @@ struct UniversalGemmPipelineAgBgCrPolicy
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
-// static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
-// c_warp_y_lengths.get(number<NDimY-1>{}));
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
+static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
+c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
@@ -125,11 +130,11 @@ struct UniversalGemmPipelineAgBgCrPolicy
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
-// static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
-// c_warp_y_lengths.get(number<NDimY-1>{}));
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
+static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
+c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
@@ -139,6 +144,26 @@ struct UniversalGemmPipelineAgBgCrPolicy
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
// TODO: this does not always have to be true; sometimes we may want a different KPack value.
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
// TODO: this does not always have to be true; sometimes we may want a different KPack value.
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
@@ -147,7 +172,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
+constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr auto MLdsLayer =
@@ -198,7 +223,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
+constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr auto NLdsLayer =
@@ -237,6 +262,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// make_tuple(sequence<1>{}, sequence<0>{}));
return b_lds_block_desc;
}
@@ -276,88 +302,30 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
+// Tile: MPerBlock X KPerBlock
+if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t M1 = GetVectorSizeA<Problem>();
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t elem_per_thr = MPerBlock * KPerBlock / BlockSize;
constexpr index_t K3 = elem_per_thr / M1; // # of loads per thr
constexpr index_t KPack = GetVectorSizeA<Problem>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * M0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * M0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
// We should take layout into account!
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr auto AccessPattern = tile_distribution_pattern::thread_raked;
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
MPerBlock,
KPerBlock,
VecLoadSize,
AccessPattern>;
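// i.e. thread-raked over the MPerBlock x KPerBlock tile, with VecLoadSize-wide
// loads along the contiguous K dimension.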
return TileEncodingPattern::Make2DStaticTileDistribution();
}
// Tile: KPerBlock X MPerBlock
else
{
// In RowMajor scenario we usually want to read whole KPerBlock tile dim
constexpr index_t K1 = GetVectorSizeA<Problem>();
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
// Coalesce reading for whole workgroup - workgroup raked pattern
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// Coalesce reading for each wavefront - wavefront raked pattern
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr auto AccessPattern = tile_distribution_pattern::thread_raked;
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock,
MPerBlock,
VecLoadSize,
AccessPattern>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
}
@@ -370,107 +338,47 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
+// Tile: KPerBlock X NPerBlock
+if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetVectorSizeB<Problem>();
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t elem_per_thr = NPerBlock * KPerBlock / BlockSize;
static_assert(elem_per_thr % N1 == 0);
constexpr index_t K3 = elem_per_thr / N1;
constexpr index_t KPack = GetVectorSizeB<Problem>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
constexpr auto AccessPattern = tile_distribution_pattern::thread_raked;
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
AccessPattern>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
// Tile: NPerBlock X KPerBlock
else
{
constexpr index_t K1 = GetVectorSizeB<Problem>();
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each block
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// coalesce reading for each warp
else
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
constexpr auto AccessPattern = tile_distribution_pattern::thread_raked;
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
NPerBlock,
KPerBlock,
VecLoadSize,
AccessPattern>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
{
-using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
+using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
+constexpr index_t M1 = GetVectorSizeA<Problem>();
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
-constexpr index_t kKPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
+constexpr index_t kKPack = GetSmemPackA<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
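// e.g. (hypothetical sizes): fp16, BlockSize=256, MPerBlock=128, KPerBlock=64, M1=8
// -> total_pixels=32, K3=4, kKPack=8, K2=2.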
constexpr index_t warp_size = get_warp_size();
@@ -506,19 +414,18 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
{
-using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
+using BLayout = remove_cvref_t<typename Problem::BLayout>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
+constexpr index_t N1 = GetVectorSizeB<Problem>();
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
-constexpr index_t kKPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
+constexpr index_t kKPack = GetSmemPackB<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
constexpr index_t warp_size = get_warp_size();
@@ -4,7 +4,6 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {