Unverified Commit 95e722a3 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

[CK_TILE] Optimize fmha splitkv & splitkv combine kernels (#1577)

* Use smaller width for lse_accum dist tensor

* Update pipeline comment

* Fix wrong distribution for lse_accum

* Remove duplicate dim in lse_accum dist encoding

* Decide fmha splitkv combine kernel kBlockSize by kM0

* Remove assumption of MPerThread=1

* Add log<4> & log<8> specialization

* Enlarge occupancy array

* Fix vector size for small tile

* Add support for kMaxSplits=8

* Re-format gemm.hpp

* Use 16x16x16 warp gemm for fwd_splitkv

* Centralize policy code changes

* Leave fp8/bf8 tile settings unchanged
parent a285d6f9
...@@ -191,7 +191,9 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m ...@@ -191,7 +191,9 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
template<> template<>
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{ {{
if (a.num_splits <= 16) {{ if (a.num_splits <= 8) {{
kernel_runner<3>::run(s, a);
}} else if (a.num_splits <= 16) {{
kernel_runner<4>::run(s, a); kernel_runner<4>::run(s, a);
}} else if (a.num_splits <= 32) {{ }} else if (a.num_splits <= 32) {{
kernel_runner<5>::run(s, a); kernel_runner<5>::run(s, a);
...@@ -239,7 +241,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ...@@ -239,7 +241,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a); return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}} }}
...@@ -551,10 +553,10 @@ class FmhaFwdSplitKVCombineKernel: ...@@ -551,10 +553,10 @@ class FmhaFwdSplitKVCombineKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 16, 16, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 16, 16, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 16, 16, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 16, 16, 16, -1),
} }
elif dtype == 'fp8' or dtype == 'bf8': elif dtype == 'fp8' or dtype == 'bf8':
return { return {
...@@ -568,16 +570,16 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: ...@@ -568,16 +570,16 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : FmhaFwdSplitKVCombineTileSize(64, 32, -1), '32' : FmhaFwdSplitKVCombineTileSize(16, 16, -1),
'64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), '64' : FmhaFwdSplitKVCombineTileSize(32, 32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
'256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1),
} }
elif dtype == 'fp8' or dtype == 'bf8': elif dtype == 'fp8' or dtype == 'bf8':
return { return {
'64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), '64' : FmhaFwdSplitKVCombineTileSize(64, 32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), '128' : FmhaFwdSplitKVCombineTileSize(64, 64, -1),
'256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), '256' : FmhaFwdSplitKVCombineTileSize(64, 128, -1),
} }
else: else:
return None return None
......
...@@ -12,6 +12,16 @@ namespace detail { ...@@ -12,6 +12,16 @@ namespace detail {
template <index_t N> template <index_t N>
struct log2; struct log2;
template <>
struct log2<4> : std::integral_constant<index_t, 2>
{
};
template <>
struct log2<8> : std::integral_constant<index_t, 3>
{
};
template <> template <>
struct log2<16> : std::integral_constant<index_t, 4> struct log2<16> : std::integral_constant<index_t, 4>
{ {
...@@ -72,18 +82,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -72,18 +82,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{ {
if constexpr(kHeadDimV <= 32) if constexpr(kHeadDimV <= 32)
{ {
constexpr std::array<int, 4> occupancy{3, 3, 3, 1}; constexpr std::array occupancy{3, 3, 3, 3, 3, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4]; return occupancy[detail::log2<kMaxSplits>::value - 2];
} }
else if constexpr(kHeadDimV <= 128) else if constexpr(kHeadDimV <= 128)
{ {
constexpr std::array<int, 4> occupancy{3, 3, 2, 1}; constexpr std::array occupancy{3, 3, 3, 3, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4]; return occupancy[detail::log2<kMaxSplits>::value - 2];
} }
else if constexpr(kHeadDimV <= 256) else if constexpr(kHeadDimV <= 256)
{ {
constexpr std::array<int, 4> occupancy{2, 2, 2, 1}; constexpr std::array occupancy{2, 2, 2, 2, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4]; return occupancy[detail::log2<kMaxSplits>::value - 2];
} }
} }
}(); }();
...@@ -138,9 +148,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -138,9 +148,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto lse_accum = make_static_distributed_tensor<LSEDataType>( auto lse_accum = make_static_distributed_tensor<LSEDataType>(
Policy::template MakeLSEaccRegTileDistribution<Problem>()); Policy::template MakeLSEaccRegTileDistribution<Problem>());
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)]) // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// this will extend the distributed tensor width so that each thread in wave have data to // and fill up -INF values outside the [kM0, num_splits] region.
// reduce.
{ {
constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
......
...@@ -10,11 +10,26 @@ namespace ck_tile { ...@@ -10,11 +10,26 @@ namespace ck_tile {
struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{ {
template <index_t BlockSize, index_t M, index_t N, typename DataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile()
{
constexpr index_t PixelsPerThread = (M * N) / BlockSize;
static_assert(0 < PixelsPerThread);
constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
constexpr index_t NPerThread = min(MaxNPerThread, PixelsPerThread);
return NPerThread;
}
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
{ {
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; return GetVectorSizeForTile<Problem::kBlockSize,
return 16 / sizeof(LSEDataType); Problem::kMaxSplits,
Problem::kM0,
typename Problem::LSEDataType>();
} }
template <typename Problem> template <typename Problem>
...@@ -47,29 +62,31 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -47,29 +62,31 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size(); MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
} }
// shape=[kMaxSplits, kM0]
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
{ {
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNumWarps = Problem::kNumWarps;
constexpr index_t kNPerBlock = Problem::kM0; constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kMaxSplits; constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t NPerThread = 16 / sizeof(LSEDataType); constexpr index_t NPerThread =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr index_t NThreads = kNPerBlock / NPerThread; constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads; constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
constexpr index_t TotalWarps = kBlockSize / get_warp_size(); constexpr index_t MPerThread = kMPerBlock / (kNumWarps * MThreadsPerWarp);
constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp);
static_assert(NThreads * NPerThread == kNPerBlock); static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock); static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>, tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, TotalWarps, MThreadsPerWarp>, tuple<sequence<MPerThread, kNumWarps, MThreadsPerWarp>,
sequence<NThreads, NPerThread>>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<1, 2>>, tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>, tuple<sequence<1>, sequence<2, 0>>,
...@@ -77,15 +94,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -77,15 +94,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
sequence<0, 1>>{}); sequence<0, 1>>{});
} }
// 3d + padding, [kMaxSplits, kM0] // 3d + padding, shape=[kMaxSplits, kM0]
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsStoreBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsStoreBlockDescriptor()
{ {
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kMaxSplits; constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0; constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType); constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}), make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
...@@ -103,15 +123,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -103,15 +123,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return lse_acc_lds_block_desc; return lse_acc_lds_block_desc;
} }
// 3d + padding, [kM0, kMaxSplits] // 3d + padding, shape=[kM0, kMaxSplits]
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor()
{ {
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kMaxSplits; constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0; constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType); constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}), make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
...@@ -134,26 +157,28 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -134,26 +157,28 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size()); constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t NThreads = get_warp_size(); constexpr index_t NThreads = 4;
constexpr index_t NPerThread = kNPerBlock / NThreads; constexpr index_t NPerThread = kNPerBlock / NThreads;
constexpr index_t MThreads = kBlockSize / NThreads; constexpr index_t MThreads = kBlockSize / NThreads;
constexpr index_t MPerThread = kMPerBlock / MThreads; constexpr index_t MPerThread = kMPerBlock / MThreads;
constexpr index_t MWarps = kBlockSize / get_warp_size();
constexpr index_t MThreadPerWarp = get_warp_size() / NThreads;
static_assert(NThreads * NPerThread == kNPerBlock); static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MThreads * MPerThread == kMPerBlock); static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<
sequence<1>, sequence<1>,
tuple<sequence<MThreads, MPerThread>, sequence<NThreads, NPerThread>>, tuple<sequence<MWarps, MThreadPerWarp, MPerThread>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2>>, tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0>>, tuple<sequence<0>, sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<1, 1>>{}); sequence<2, 1>>{});
} }
template <typename Problem> template <typename Problem>
......
...@@ -115,7 +115,8 @@ struct BlockFmhaSplitKVCombinePipelineProblem ...@@ -115,7 +115,8 @@ struct BlockFmhaSplitKVCombinePipelineProblem
using ODataType = remove_cvref_t<ODataType_>; using ODataType = remove_cvref_t<ODataType_>;
using Traits = remove_cvref_t<Traits_>; using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = 256; static constexpr index_t kNumWarps = kM0_ / (get_warp_size() / 4);
static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kHeadDimV = HeadDimV_; static constexpr index_t kHeadDimV = HeadDimV_;
......
...@@ -88,22 +88,33 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -88,22 +88,33 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename Problem::BlockFmhaShape::Gemm0WarpTile>>; typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> && std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>) std::is_same_v<typename Problem::SaccDataType, float>)
{ {
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else // WarpGemmM == 16
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
} }
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> && else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> && std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>) std::is_same_v<typename Problem::SaccDataType, float>)
{ {
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else // WarpGemmM == 16
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
} }
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> && else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> && std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>) std::is_same_v<typename Problem::SaccDataType, float>)
{ {
static_assert(WarpGemmM == 32);
// TODO: hard coded here. Otherwise, it may incorrect result // TODO: hard coded here. Otherwise, it may incorrect result
constexpr index_t swizzle_factor = 4; constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution< return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
......
...@@ -23,12 +23,12 @@ ...@@ -23,12 +23,12 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment