Commit 22ab193c authored by carlushuang

add ck_tile for matrix_core swizzle kernel

parent b2e95e21
# generate a list of kernels, but do not actually emit files at the configure stage
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
)
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
)
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
# as the current cmake list, otherwise cmake will not figure out the dependency properly
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
# not using add_example_executable() to add this target, since we don't want it
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_FWD}")
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# not using add_example_executable() to add this target, since we don't want it
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
# NOTE: this is dangerous, since it makes the whole kernel flush denormals to zero
# WIP with compiler team for an exp2 intrinsic..., then remove this
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
set(FMHA_FWD_FAST_EXP2 true)
endif()
set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)
set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template so the source compiles without explicitly
# declaring function specializations ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
# Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files: execvp: /bin/sh: Argument list too long
# however, this property affects the whole build globally
# TODO: consider generating a makefile ourselves instead
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <string>
#include <type_traits>
//
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, top_k=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
// number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_tokens_post_padded : top_k * input_tokens + num_experts * (M_a - 1)
//                          (e.g. 3 * 5 + 6 * (4 - 1) = 33 here)
// * this is an upper bound and can be larger than the actual size, since the actual
//   token count is only known on the GPU
//
// sorted_token_ids_ptr   : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
//                           0, 1, 2, 4]
//                          |- exp-0 -|- exp-1 -|- exp-2 -|------- exp-3 ------|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
// c, f, i, o]
//
// * length is max_tokens_post_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_tokens_post_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * differences from vLLM:
// 1) the token_id stored in sorted_token_ids_ptr is the actual token_id, not the
//    token_id * top_k expanded id
// 2) we additionally need sorted_weight_ptr
// 3) we use num_sorted_tiles_ptr, which is already divided by M_a
//
// * buffers used for indexing:
// 1) sorted_token_ids_ptr
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
//
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
// number)
//
// we generate the original row/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]] (A..E are hex 10..14)
// let x be one element of the above, then:
// topk_row_id (token_id) = x % num_tokens(5)
// topk_col_id (k_id)     = x / num_tokens
// topk_row_id/col_id can be used to access the original topk_ids/topk_weight
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
//
//
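//
// The following is an illustrative host-side sketch (not part of this commit; the
// function name moe_sort_ref is made up) of how [indexing implementation-1] could
// produce sorted_token_ids/sorted_weight/sorted_expert_ids, matching the example above:
#if 0
#include <cstddef>
#include <vector>

inline void moe_sort_ref(const std::vector<std::vector<int>>& topk_ids,      // [tokens][top_k]
                         const std::vector<std::vector<float>>& topk_weight, // [tokens][top_k]
                         int num_experts,
                         int m_a, // block size M_a
                         std::vector<int>& sorted_token_ids,
                         std::vector<float>& sorted_weight,
                         std::vector<int>& sorted_expert_ids)
{
    const int num_tokens = static_cast<int>(topk_ids.size());
    const int pad_id     = num_tokens + 1; // sentinel for padded slots (the 6s above)
    for(int e = 0; e < num_experts; e++)
    {
        // gather tokens (and their weights) routed to expert e, in token order
        std::vector<int> tok;
        std::vector<float> w;
        for(int t = 0; t < num_tokens; t++)
            for(std::size_t k = 0; k < topk_ids[t].size(); k++)
                if(topk_ids[t][k] == e)
                {
                    tok.push_back(t);
                    w.push_back(topk_weight[t][k]);
                }
        // pad each expert's slice up to a multiple of M_a (an empty expert still owns one tile)
        int padded = ((static_cast<int>(tok.size()) + m_a - 1) / m_a) * m_a;
        padded     = padded == 0 ? m_a : padded;
        for(int i = 0; i < padded; i++)
        {
            sorted_token_ids.push_back(i < static_cast<int>(tok.size()) ? tok[i] : pad_id);
            sorted_weight.push_back(i < static_cast<int>(w.size()) ? w[i] : 0.f); // '*' above
            if(i % m_a == 0)
                sorted_expert_ids.push_back(e); // one entry per M_a-sized tile
        }
    }
}
#endif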
namespace ck_tile {
// This is a scatter/gather back-to-back (b2b) group-gemm
template <typename TilePartitioner_, typename FusedMoePipeline_, typename EpiloguePipeline_>
struct FusedMoeKernel
{
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FusedMoePipeline = ck_tile::remove_cvref_t<FusedMoePipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FusedMoePipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FusedMoePipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FusedMoePipeline::Problem::kBlockPerCu;
using ADataType = ck_tile::remove_cvref_t<typename FusedMoePipeline::ADataType>;
using GDataType = ck_tile::remove_cvref_t<typename FusedMoePipeline::GDataType>;
using UDataType = ck_tile::remove_cvref_t<typename FusedMoePipeline::UDataType>;
using DDataType = ck_tile::remove_cvref_t<typename FusedMoePipeline::DDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FusedMoePipeline::ODataType>;
using AccDataType = ck_tile::remove_cvref_t<typename FusedMoePipeline::AccDataType>;
using ScaleDataType = ck_tile::remove_cvref_t<typename FusedMoePipeline::ScaleDataType>;
using DLayout = ck_tile::remove_cvref_t<typename FusedMoePipeline::DLayout>;
using FusedMoeTileShape = ck_tile::remove_cvref_t<typename FusedMoePipeline::FusedMoeTileShape>;
static constexpr bool kPadDimSize = FusedMoePipeline::kPadDimSize;
static constexpr bool kPadHiddenSize = FusedMoePipeline::kPadHiddenSize;
static constexpr bool kPadSeqLenQ = FusedMoePipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FusedMoePipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FusedMoePipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FusedMoePipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FusedMoePipeline::BiasEnum;
static constexpr bool kStoreLSE = FusedMoePipeline::kStoreLSE;
static constexpr bool kHasDropout = FusedMoePipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FusedMoePipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FusedMoePipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
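// e.g. t2s<ck_tile::fp16_t>::name is "fp16"; GetName() below consumes these strings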
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
using bfs = typename FusedMoePipeline::BlockFmhaShape;
using gbr = typename bfs::Gemm0BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
auto pn = [&] () {
std::string n;
if (kPadSeqLenQ) n += "s";
if (kPadSeqLenK) n += "sk";
if (kPadHeadDimQ) n += "d";
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<ADataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FusedMoePipeline::name) + "_" +
"v" + (std::is_same_v<DLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
// clang-format on
}
template <ck_tile::index_t I> // introduce a template arg to avoid the duplicated-base-class problem
struct FusedMoeEmptyKargs
{
};
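// illustrative sketch (not in this commit) of why the template arg matters: the same
// empty struct cannot be inherited twice, while distinct instantiations can, e.g.
//   struct Kargs : FusedMoeCommonKargs,
//                  std::conditional_t<kHasX, XKargs, FusedMoeEmptyKargs<0>>, // XKargs/YKargs
//                  std::conditional_t<kHasY, YKargs, FusedMoeEmptyKargs<1>>  // are hypothetical
//   { };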
// tensors:
// 1. act (A): input feature map
// 2. gate (G): B matrix for first gemm, output will do activation(Silu)
// 3. up (U): B matrix for first gemm
// 4. down (D): B matrix for second gemm
struct FusedMoeCommonKargs
{
const void* a_ptr;
const void* g_ptr;
const void* u_ptr;
const void* d_ptr;
// const void* w_ptr; //topk-weight
void* o_ptr;
const void* sorted_token_ids_ptr;
const void* sorted_weight_ptr;
const void* sorted_expert_ids_ptr;
// const void* num_tokens_post_padded_ptr;
const void* num_sorted_tiles_ptr;
ck_tile::index_t dim_size;
ck_tile::index_t hidden_size;
ck_tile::index_t num_tokens; // input number of tokens for current iteration
ck_tile::index_t num_experts; // number of groups
// ck_tile::index_t top_k; // need this?
ck_tile::index_t stride_a;
ck_tile::index_t stride_g;
ck_tile::index_t stride_u;
ck_tile::index_t stride_d;
ck_tile::index_t stride_o;
ck_tile::index_t stride_g_expert;
ck_tile::index_t stride_u_expert;
ck_tile::index_t stride_d_expert;
};
using Kargs = FusedMoeCommonKargs; // std::conditional_t<kIsGroupMode, FusedMoeGroupModeKargs,
// FusedMoeBatchModeKargs>;
// host args are used by the host API
// and should be a POD data structure
struct FusedMoeCommonHargs
{
const void* a_ptr;
const void* g_ptr;
const void* u_ptr;
const void* d_ptr;
// const void* w_ptr; //topk-weight
void* o_ptr;
const void* sorted_token_ids_ptr;
const void* sorted_weight_ptr;
const void* sorted_expert_ids_ptr;
// const void* num_tokens_post_padded_ptr;
const void* num_sorted_tiles_ptr;
ck_tile::index_t dim_size;
ck_tile::index_t hidden_size;
ck_tile::index_t num_tokens; // input number of tokens for current iteration
ck_tile::index_t num_experts; // number of groups
// ck_tile::index_t top_k; // need this?
ck_tile::index_t stride_a;
ck_tile::index_t stride_g;
ck_tile::index_t stride_u;
ck_tile::index_t stride_d;
ck_tile::index_t stride_o;
ck_tile::index_t stride_g_expert;
ck_tile::index_t stride_u_expert;
ck_tile::index_t stride_d_expert;
};
using Hargs = FusedMoeCommonHargs;
CK_TILE_HOST static constexpr auto ToKargs(const Hargs& hargs)
{
    // Hargs and Kargs are currently member-for-member identical, so copy through
    Kargs kargs{hargs.a_ptr,
                hargs.g_ptr,
                hargs.u_ptr,
                hargs.d_ptr,
                hargs.o_ptr,
                hargs.sorted_token_ids_ptr,
                hargs.sorted_weight_ptr,
                hargs.sorted_expert_ids_ptr,
                hargs.num_sorted_tiles_ptr,
                hargs.dim_size,
                hargs.hidden_size,
                hargs.num_tokens,
                hargs.num_experts,
                hargs.stride_a,
                hargs.stride_g,
                hargs.stride_u,
                hargs.stride_d,
                hargs.stride_o,
                hargs.stride_g_expert,
                hargs.stride_u_expert,
                hargs.stride_d_expert};
    return kargs;
}
CK_TILE_HOST static constexpr auto GridSize(index_t num_cu, index_t blocks_per_cu)
{
return TilePartitioner::GridSize(num_cu, blocks_per_cu);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(FusedMoePipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
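// illustrative host-side launch flow (a sketch; kentry is a hypothetical __global__
// trampoline that invokes FusedMoeKernel{}(kargs), not part of this commit):
//   auto kargs = FusedMoeKernel::ToKargs(hargs);
//   dim3 grids = FusedMoeKernel::GridSize(num_cu, kBlockPerCu);
//   kentry<FusedMoeKernel><<<grids, BlockSize(), 0, stream>>>(kargs);
//   // LDS is allocated statically inside operator(), so dynamic smem is 0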
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
ck_tile::index_t num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const ck_tile::index_t*>(kargs.num_sorted_tiles_ptr));
ck_tile::index_t tile_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
// persistent loop
while(true)
{
const auto [sorted_tile_id, hidden_tile_id] =
TilePartitioner{}(tile_id, num_sorted_tiles, kargs.hidden_size);
if(sorted_tile_id >= num_sorted_tiles)
return;
ck_tile::index_t expert_id =
__builtin_amdgcn_readfirstlane(reinterpret_cast<const ck_tile::index_t*>(
kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along hidden_size
ck_tile::index_t hidden_id =
__builtin_amdgcn_readfirstlane(hidden_tile_id * FusedMoeTileShape::kN_g);
const auto a_coord = FusedMoePipeline::GetAIndex(); // 2d thread offset, [i_row, i_col]
const auto token_coord =
a_coord[number<0>{}] + sorted_tile_id * FusedMoeTileShape::kM_a;
index_t token_id =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[token_coord];
ScaleDataType scale =
reinterpret_cast<const ScaleDataType*>(kargs.sorted_weight_ptr)[token_coord];
const auto a_gtile_window = [&]() {
const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.num_tokens, kargs.dim_size),
make_tuple(kargs.stride_a, 1),
number<FusedMoePipeline::kAlignmentA>{},
number<1>{});
// gather is here
const auto a_gather_view_ = transform_tensor_view(
a_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.dim_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto a_gtile_window_ = make_tile_window(
a_gather_view_,
make_tuple(number<FusedMoeTileShape::kM_a>{}, number<FusedMoeTileShape::kK_a>{}),
{0, 0});
return a_gtile_window_;
}();
const auto g_gtile_window = [&]() {
const GDataType* g_ptr =
reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * kargs.stride_g_expert +
hidden_id * kargs.stride_g;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(kargs.hidden_size, kargs.dim_size),
make_tuple(kargs.stride_g, 1),
number<FusedMoePipeline::kAlignmentG>{},
number<1>{});
const auto g_view_1_ = pad_tensor_view(
g_view_,
make_tuple(number<FusedMoeTileShape::kN_g>{}, number<FusedMoeTileShape::kK_a>{}),
sequence<kPadHiddenSize, kPadDimSize>{});
const auto g_gtile_window_ = make_tile_window(
g_view_1_,
make_tuple(number<FusedMoeTileShape::kN_g>{}, number<FusedMoeTileShape::kK_a>{}),
{0, 0});
return g_gtile_window_;
}();
const auto u_gtile_window = [&]() {
const UDataType* u_ptr =
reinterpret_cast<const UDataType*>(kargs.u_ptr) +
static_cast<long_index_t>(expert_id) * kargs.stride_u_expert +
hidden_id * kargs.stride_u;
const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
u_ptr,
make_tuple(kargs.hidden_size, kargs.dim_size),
make_tuple(kargs.stride_u, 1),
number<FusedMoePipeline::kAlignmentU>{},
number<1>{});
const auto u_view_1_ = pad_tensor_view(
u_view_,
make_tuple(number<FusedMoeTileShape::kN_u>{}, number<FusedMoeTileShape::kK_a>{}),
sequence<kPadHiddenSize, kPadDimSize>{});
const auto u_gtile_window_ = make_tile_window(
u_view_1_,
make_tuple(number<FusedMoeTileShape::kN_u>{}, number<FusedMoeTileShape::kK_a>{}),
{0, 0});
return u_gtile_window_;
}();
const auto d_gtile_window = [&]() {
const DDataType* d_ptr = [&]() {
    if constexpr(std::is_same_v<DLayout, ck_tile::tensor_layout::gemm::RowMajor>)
    {
        return reinterpret_cast<const DDataType*>(kargs.d_ptr) +
               static_cast<long_index_t>(expert_id) * kargs.stride_d_expert +
               hidden_id * kargs.stride_d;
    }
    else
    {
        return reinterpret_cast<const DDataType*>(kargs.d_ptr) +
               static_cast<long_index_t>(expert_id) * kargs.stride_d_expert +
               hidden_id;
    }
}();
if constexpr(std::is_same_v<DLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(kargs.hidden_size, kargs.dim_size),
make_tuple(kargs.stride_d, 1),
number<FusedMoePipeline::kAlignmentD>{},
number<1>{});
const auto d_view_1_ = pad_tensor_view(
d_view_,
make_tuple(number<FusedMoeTileShape::kK_y>{}, number<FusedMoeTileShape::kN_d>{}),
sequence<kPadHiddenSize, kPadDimSize>{});
const auto d_gtile_window_ = make_tile_window(
d_view_1_,
make_tuple(number<FusedMoeTileShape::kK_y>{}, number<FusedMoeTileShape::kN_d>{}),
{0, 0});
return d_gtile_window_;
}
else
{
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(kargs.dim_size, kargs.hidden_size),
make_tuple(kargs.stride_d, 1),
number<FusedMoePipeline::kAlignmentD>{},
number<1>{});
const auto d_view_1_ = pad_tensor_view(
d_view_,
make_tuple(number<FusedMoeTileShape::kN_d>{}, number<FusedMoeTileShape::kK_y>{}),
sequence<kPadHiddenSize, kPadDimSize>{});
const auto d_gtile_window_ = make_tile_window(
d_view_1_,
make_tuple(number<FusedMoeTileShape::kN_d>{}, number<FusedMoeTileShape::kK_y>{}),
{0, 0});
return d_gtile_window_;
}
}();
auto o_gtile_window = [&]() {
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
const auto o_view_ = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(kargs.num_tokens, kargs.dim_size),
make_tuple(kargs.stride_o, 1),
number<FusedMoePipeline::kAlignmentO>{},
number<1>{});
// scatter is here
const auto o_scatter_view_ = transform_tensor_view(
o_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.dim_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto o_gtile_window_ = make_tile_window(
o_scatter_view_,
make_tuple(number<FusedMoeTileShape::kM_a>{}, number<FusedMoeTileShape::kK_a>{}),
{0, 0});
return o_gtile_window_;
}();
// run the fused pipeline
FusedMoePipeline{}(a_gtile_window,
g_gtile_window,
u_gtile_window,
d_gtile_window,
o_gtile_window,
scale,
smem_ptr);
tile_id += gridDim.x;
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename FusedMoeTileShape_>
struct FusedMoeTilePartitioner_PersistentSplitD
{
using FusedMoeTileShape = ck_tile::remove_cvref_t<FusedMoeTileShape_>;
static constexpr index_t kM_a = FusedMoeTileShape::kM_a;
static constexpr index_t kN_g = FusedMoeTileShape::kN_g;
static constexpr index_t kN_u = FusedMoeTileShape::kN_u;
static constexpr index_t kK_a = FusedMoeTileShape::kK_a;
static constexpr index_t kN_d = FusedMoeTileShape::kN_d;
static constexpr const char* name = "psd"; // expert x hidden
CK_TILE_DEVICE auto operator()(ck_tile::index_t tile_id,
ck_tile::index_t /*num_sorted_tiles*/,
ck_tile::index_t hidden_size)
{
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const index_t num_hidden_tiles = ck_tile::integer_divide_ceil(hidden_size, kN_g);
const auto [sorted_tile_id, hidden_tile_id] = f(tile_id, num_hidden_tiles);
return ck_tile::make_tuple(sorted_tile_id, hidden_tile_id);
}
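// worked example (illustrative): hidden_size = 1024, kN_g = 128 -> num_hidden_tiles = 8,
// so tile_id = 19 decomposes into (sorted_tile_id, hidden_tile_id) = (2, 3)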
// persistent
CK_TILE_HOST static constexpr auto GridSize(index_t num_cu, index_t blocks_per_cu)
{
// TODO: this may need tuning
index_t grids = num_cu * blocks_per_cu;
return dim3(grids);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
struct BlockFmhaPipelineQRKSVSAsync
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using GDataType = remove_cvref_t<typename Problem::GDataType>;
using UDataType = remove_cvref_t<typename Problem::UDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ScaleDataType = remove_cvref_t<typename Problem::ScaleDataType>;
using FusedMoeTileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;
using VLayout = remove_cvref_t<typename FusedMoeTileShape::VLayout>;
static constexpr bool kQLoadOnce = true; // whether the q tile loads the whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always supports padding; hdim_q/v support multiples of the vector size (like 8x)
// only seq_k padding needs special care (oob must set p to -INF instead of zero)
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
Problem::kPadHeadDimV == true);
static constexpr bool kPadSeqLenQ = true;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create the tensor view (and decide the buffer_load vector length)
// ... together with the tensor distribution. the tensor distribution should be able to override this
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
#endif
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
// minimize occupancy
if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)
{
return 1;
}
if constexpr(kK0BlockLength <= 32)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
FmhaMask::IsMasking)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 64)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 2;
else
return 3;
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 256)
{
return 1;
}
}
}();
static constexpr const char* name = "qr_async";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetAIndex()
{
constexpr auto a_dist = Policy::template MakeAGlobalTileDistribution<Problem>();
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetOIndex()
{
constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
const auto o_coord = o_dist.calculate_index();
return o_coord;
}
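// usage note (see FusedMoeKernel::operator() above): the kernel adds sorted_tile_id * kM_a
// to the row part of these coords to pick the sorted_token_ids_ptr entry that each thread
// gathers (A) or scatters (O)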
template <typename AGlobalTensorView,
typename GGlobalTileWindow,
typename UGlobalTileWindow,
typename DGlobalTileWindow,
typename OGlobalTensorView>
CK_TILE_DEVICE auto operator()(const AGlobalTensorView& a_gtile_window_tmp,
const GGlobalTileWindow& g_gtile_window_tmp,
const UGlobalTileWindow& u_gtile_window_tmp,
const DGlobalTileWindow& d_gtile_window_tmp,
OGlobalTensorView& o_gtile_window_tmp,
// const void * sorted_weight_ptr,
ScaleDataType scale,
void* smem_ptr,
index_t dim_size,
index_t hidden_size)
{
constexpr auto gemm_0 = Policy::template GetGemm0<Problem>();
constexpr auto gemm_1 = Policy::template GetGemm1<Problem>();
auto a_gtile_window =
make_tile_window(a_gtile_window_tmp.get_bottom_tensor_view(),
a_gtile_window_tmp.get_window_lengths(),
a_gtile_window_tmp.get_window_origin(),
Policy::template MakeAGlobalTileDistribution<Problem>());
auto g_gtile_window =
make_tile_window(g_gtile_window_tmp.get_bottom_tensor_view(),
g_gtile_window_tmp.get_window_lengths(),
g_gtile_window_tmp.get_window_origin(),
Policy::template MakeGGlobalTileDistribution<Problem>());
auto u_gtile_window =
make_tile_window(u_gtile_window_tmp.get_bottom_tensor_view(),
u_gtile_window_tmp.get_window_lengths(),
u_gtile_window_tmp.get_window_origin(),
Policy::template MakeUGlobalTileDistribution<Problem>());
auto d_gtile_window =
make_tile_window(d_gtile_window_tmp.get_bottom_tensor_view(),
d_gtile_window_tmp.get_window_lengths(),
d_gtile_window_tmp.get_window_origin(),
Policy::template MakeDGlobalTileDistribution<Problem>());
auto o_gtile_window =
make_tile_window(o_gtile_window_tmp.get_bottom_tensor_view(),
o_gtile_window_tmp.get_window_lengths(),
o_gtile_window_tmp.get_window_origin(),
Policy::template MakeOGlobalTileDistribution<Problem>());
constexpr auto k_per_block_0 = Problem::FusedMoeTileShape::kK_a;
const index_t loops_0 = (dim_size + k_per_block_0 - 1) / k_per_block_0;
constexpr auto n_per_block_1 = Problem::FusedMoeTileShape::kN_d;
const index_t loops_1 = (dim_size + n_per_block_1 - 1) / n_per_block_1;
auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr); // WIP: A tile placed at the start of LDS
auto a_lds_store_window =
    make_tile_window(make_tensor_view<address_space_enum::lds>(
                         a_smem_ptr, Policy::template MakeALdsStoreBlockDescriptor<Problem>()),
                     Policy::template MakeALdsStoreBlockDescriptor<Problem>().get_lengths(),
                     {0, 0});
async_load_tile_raw(a_lds_store_window, a_gtile_window);
for(index_t i_0 = 0; i_0 < loops_0; i_0++) {} // WIP: pipeline main loop not implemented yet
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& /*k_element_func*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
DropoutType& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
// K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumPrefetchK>{});
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
auto k_lds_load = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0});
},
number<Policy::NumPrefetchK>{});
#else
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_Load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
#endif
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = make_tile_window(
q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
q_dram_window.init_raw();
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto q = decltype(load_tile(q_dram_window)){}; // reg tile = copy of the tensor view
set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
if constexpr(kStoreLSE)
{
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
buffer_load_fence(0); // rocm-6.1: if the whole tile is masked out we need fence(0),
                      // otherwise there are compute errors (maybe a compiler bug?)
// Note: o_acc is already cleared here, just return it
return o_acc;
}
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
k_dram_window.init_raw();
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = [&]() {
if constexpr(kPadSeqLenK &&
(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)))
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
(void)q_element_func; // rocm-6.x: applying the q element func here causes scratch usage on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
// main loop
do
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window,
k_oob_ck,
k_pre_np);
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load[number<LdsSeq.at(number<i_k0>{})>{}]);
#else
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
#endif
});
}
// TODO: this fixes a bug when the loop count is smaller than 2:
// without it, the following fence/barrier get scheduled inside the 1st loop
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence();
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load[number<LdsSeq.at(number<k0_loops - 1>{})>{}]);
#else
get_slice_tile(
k_lds_load,
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
#endif
}
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be a materialized mask containing -inf values, which needs
/// consideration; alibi does not have this problem
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
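// math note on the FAST_EXP2 path above: e^(scale*(s - m)) is evaluated as
// exp2(scale_s*s - scale_s*m), which is only exact if the host folds log2(e) into
// scale_s (i.e. scale_s = scale * log2(e)); the LSE store below undoes that folding
// via R_LOG2E = 1/log2(e). the non-FAST path keeps the plain p = exp(s - m) form.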
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this uses a different equation from the FA v2 paper,
// but produces the correct result.
// Is the equation in the paper wrong?
o_acc(i_j_idx) *= tmp;
});
});
if constexpr(kHasDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
const auto p = [&]() {
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<PDataType>(
tile_elementwise_in(p_compute_element_func, p_compute));
else
return cast_tile<PDataType>(
tile_elementwise_in(p_compute_element_func, p_compute));
}();
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
}
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
}
if constexpr(i_k1 < k1_loops - 1)
move_tile_window(v_dram_window, {0, kK1});
});
}
i_total_loops++;
if(i_total_loops < num_total_loop)
{
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
}
// tail
{
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
} while(i_total_loops < num_total_loop);
// store lse
if constexpr(kStoreLSE)
{
auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
}
else
{
lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
}
#else
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif
});
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
position_encoding,
scale_s,
smem_ptr,
dropout);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
struct FusedMoePipelinePolicy
{
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{
// TODO:
return 1;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
// using async
static constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
static constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentG()
{
static constexpr index_t copy_bytes = [&]() {
if constexpr(Problem::Traits::GateUpPreShuffled)
{
return 4 * 4;
}
else
{
return 4 * GetAsyncCopyDwords();
}
}();
static constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentU()
{
static constexpr index_t copy_bytes = [&]() {
if constexpr(Problem::Traits::GateUpPreShuffled)
{
return 4 * 4;
}
else
{
return 4 * GetAsyncCopyDwords();
}
}();
static constexpr index_t data_bytes = sizeof(typename Problem::UDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentD()
{
static constexpr index_t copy_bytes = [&]() {
if constexpr(Problem::Traits::DownPreShuffled)
{
return 4 * 4;
}
else
{
return 4 * GetAsyncCopyDwords();
}
}();
static constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
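// worked numbers (illustrative): with 2-byte fp16 data, the plain async path moves
// 1 dword = 4 bytes per lane -> alignment of 2 elements; the pre-shuffled path moves
// 4 dwords = 16 bytes per lane -> alignment of 8 elements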
template <typename DataType_>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
{
    // TODO: this is for 3d layout
    return 16 / sizeof(remove_cvref_t<DataType_>);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackA()
{
return GetSmemKPack<typename Problem::ADataType>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackG()
{
return GetSmemKPack<typename Problem::GDataType>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackU()
{
return GetSmemKPack<typename Problem::UDataType>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackD()
{
return GetSmemKPack<typename Problem::DDataType>();
}
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
{
constexpr index_t K_vec = Alignment;
constexpr index_t K_rem = KPerBlock / K_vec;
if constexpr(get_warp_size() < K_rem)
{
static_assert(K_rem % get_warp_size() == 0);
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
constexpr index_t K_wav = K_rem / get_warp_size();
static_assert(K_wav <= NumWarps, "thread repeat along K is not supported yet");
constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / M_wav;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
constexpr index_t K_lan = K_rem;
constexpr index_t M_lan = get_warp_size() / K_lan;
constexpr index_t M_wav = NumWarps;
static_assert(MPerBlock % (M_lan * M_wav) == 0,
"this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav, M_lan>, sequence<K_lan, K_vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
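// worked example (illustrative): MPerBlock=32, KPerBlock=128, NumWarps=4, Alignment=8 on
// a 64-lane wave -> K_vec=8, K_rem=16 < 64, so the else-branch applies:
// K_lan=16, M_lan=4, M_wav=4, M_rep=32/(4*4)=2 (each thread repeats twice along M)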
// optimized version for async
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async()
{
constexpr index_t K_vec = Alignment;
constexpr index_t K_rem = KPerBlock / K_vec;
if constexpr(get_warp_size() < K_rem)
{
static_assert(K_rem % get_warp_size() == 0);
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
constexpr index_t K_wav = K_rem / get_warp_size();
static_assert(K_wav <= NumWarps, "thread repeat along K is not supported yet");
constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / M_wav;
// NOTE: no swap, but hard to avoid LDS bank conflict
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
constexpr index_t K_lan = K_rem;
constexpr index_t M_lan = get_warp_size() / K_lan;
constexpr index_t M_wav = NumWarps;
static_assert(MPerBlock % (M_lan * M_wav) == 0,
"this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
// NOTE: swapped for LDS load bank conflict free
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_lan, M_wav>, sequence<K_lan, K_vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
// Caution: this requires the global memory to be pre-shuffled into the mfma layout,
// maximizing L1/L2 channel utilization while skipping LDS
/*
(b) n0 n1 n2 k0 k1 k2
klanes
|
nr 4 kr 4 16 8
(b) n0 n1 k0 k1 n2 k2 -> kthreads
| |
V V
waves nlanes
klanes
|
nr kr 4 4 16 8
(b) n0 k0 n1 k1 n2 k2 -> kthreads
| |
V V
waves nlanes
*/
template <typename BlockTile, typename BlockWarps, typename WarpGemm, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled_NxK()
{
static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0);
static_assert(BlockWarps{}.at(number<0>{}) == 1 && BlockWarps{}.at(number<2>{}) == 1);
static constexpr index_t NumWarps =
reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
constexpr index_t NPerBlock = BlockTile{}.at(number<1>{});
constexpr index_t KPerBlock = BlockTile{}.at(number<2>{});
constexpr index_t K2 = Alignment;
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t N1 = NumWarps;
static_assert(NPerBlock % (N1 * N2) == 0);
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t K0 = KPerBlock / (K1 * K2);
constexpr index_t N0 = NPerBlock / (N1 * N2);
// mapping follows the diagram above: warps iterate along n1, lanes along (k1, n2),
// each lane reads an Alignment-sized vector along k2, and n0/k0 are per-thread repeats
// (reconstructed sketch; the exact P/Y ordering should be verified against WarpGemm)
return make_static_tile_distribution(
    tile_distribution_encoding<sequence<1>,
                               tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
                               tuple<sequence<1>, sequence<2, 1>>,
                               tuple<sequence<1>, sequence<1, 2>>,
                               sequence<1, 2, 2>,
                               sequence<0, 0, 2>>{});
}
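    // Worked example (illustrative numbers only): for the 32x32x8 mfma with
    // NPerBlock = 256, KPerBlock = 128, NumWarps = 4 we get N2 = kAMLane = 32,
    // K1 = kABKLane = 2, K2 = Alignment = 8 and N1 = 4, hence
    // N0 = 256 / (4 * 32) = 2 and K0 = 128 / (2 * 8) = 8.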
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeAGlobalTileDistribution()
{
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentA<Problem>();
return MakeGlobalTileDistribution_SimpleMxK_Async<kMPerBlock,
kKPerBlock,
NumWarps,
Alignment>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGGlobalTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentG<Problem>();
return MakeGlobalTileDistribution_SimpleMxK_Async<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeUGlobalTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_u;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentU<Problem>();
return MakeGlobalTileDistribution_SimpleMxK_Async<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeDGlobalTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_d;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_y;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentD<Problem>();
return MakeGlobalTileDistribution_SimpleMxK_Async<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment>();
}
template <index_t MPerBlock,
index_t KPerBlock,
index_t NumWarps,
index_t Alignment,
index_t KPack,
index_t NumPrefetch>
CK_TILE_HOST_DEVICE static constexpr auto MakeSmemLoadTileDescriptor_SimpleMxK_Async()
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kBlockSize = ck_tile::get_warp_size() * NumWarps; // Problem::kBlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
// constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector =
Alignment; // GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
static_assert(warpSize * KVector >= KPerBlock && warpSize * KVector % KPerBlock == 0);
constexpr index_t LanesPerK = KPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = MPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == MPerBlock * KPerBlock / (kBlockSize * KVector));
constexpr index_t BufferSize = NumIssues * NumWarps * (warpSize * KVector + kPad);
constexpr auto lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumPrefetch>{}, // num_buffers
number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<KPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<BufferSize>{},
number<NumWarps*(warpSize * KVector + kPad)>{},
number<warpSize * KVector + kPad>{},
number<KPerBlock>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<NumPrefetch>{},
number<NumIssues>{},
number<LaneGroups>{},
number<NumWarps>{})),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 1, 3, 2>{}, sequence<4, 5>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc;
}
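    // Worked example (illustrative numbers only): with warpSize = 64,
    // MPerBlock = 128, KPerBlock = 64, NumWarps = 4, Alignment(KVector) = 8 and
    // KPack = 8 we get LanesPerK = 64 / 8 = 8, LaneGroups = 64 / 8 = 8,
    // NumIssues = 128 / (8 * 4) = 4 (which equals 128 * 64 / (256 * 8)) and
    // BufferSize = 4 * 4 * (64 * 8 + 8) = 8320 elements per prefetch buffer.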
    template <index_t MPerBlock,
              index_t KPerBlock,
              index_t NumWarps,
              index_t KPack,
              index_t Alignment,
              index_t IBuf = 0>
    CK_TILE_HOST_DEVICE static constexpr auto
    MakeSmemStoreBlockDescriptor_SimpleMxK_Async(number<IBuf> = number<0>{})
    {
        constexpr index_t kBlockSize = ck_tile::get_warp_size() * NumWarps; // Problem::kBlockSize;
        constexpr index_t warpSize   = ck_tile::get_warp_size();
        // constexpr index_t KPack     = GetSmemKPackK<Problem>();   // this is for lds
        // constexpr index_t Alignment = GetAlignmentK<Problem>();   // this is for global load
        constexpr index_t kPad =
            KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
        static_assert(warpSize * Alignment >= KPerBlock && warpSize * Alignment % KPerBlock == 0);
        constexpr index_t LanesPerK =
            KPerBlock / Alignment; // how many lanes (within a wave) load K
        constexpr index_t LaneGroups =
            warpSize /
            LanesPerK; // how many groups (within a wave); they may load different M rows, but the same K
        constexpr index_t NumIssues = MPerBlock / (LaneGroups * NumWarps);
        static_assert(NumIssues == MPerBlock * KPerBlock / (kBlockSize * Alignment));
        // footprint of one prefetch buffer, used to offset into the IBuf-th buffer
        constexpr index_t BufferSize = NumIssues * NumWarps * (warpSize * Alignment + kPad);
        constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
            make_tuple(number<NumIssues>{},  // n0
                       number<LaneGroups>{}, // n1
                       number<NumWarps>{},   // n2
                       number<LanesPerK>{},  // k0
                       number<Alignment>{}), // k1
            make_tuple(number<NumWarps*(warpSize * Alignment + kPad)>{},
                       number<KPerBlock>{},
                       number<warpSize * Alignment + kPad>{},
                       number<Alignment>{},
                       number<1>{}),
            number<IBuf * BufferSize>{},
            number<Alignment>{},
            number<1>{});
        // TODO this layout is hard coded, and will be used in async copy buffer view load
        // in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
        constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
            k_lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<NumIssues>{}),
                       make_pass_through_transform(number<NumWarps>{}),
                       make_merge_transform(make_tuple(
                           number<LaneGroups>{}, number<LanesPerK>{}, number<Alignment>{}))),
            make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
        return k_lds_block_desc_issues_warps_lanes;
    }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeASmemLoadTileDistribution()
{
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentA<Problem>();
constexpr index_t KPack = GetSmemKPackA<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchA;
return MakeSmemLoadTileDescriptor_SimpleMxK_Async<kMPerBlock,
kKPerBlock,
NumWarps,
Alignment,
KPack,
NumPrefetch>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeASmemStoreTileDistribution()
{
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentA<Problem>();
constexpr index_t KPack = GetSmemKPackA<Problem>();
        return MakeSmemStoreBlockDescriptor_SimpleMxK_Async<kMPerBlock,
                                                            kKPerBlock,
                                                            NumWarps,
                                                            KPack,
                                                            Alignment>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGSmemLoadTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentG<Problem>();
constexpr index_t KPack = GetSmemKPackG<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchG;
return MakeSmemLoadTileDescriptor_SimpleMxK_Async<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment,
KPack,
NumPrefetch>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGSmemStoreTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentG<Problem>();
constexpr index_t KPack = GetSmemKPackG<Problem>();
        return MakeSmemStoreBlockDescriptor_SimpleMxK_Async<kNPerBlock,
                                                            kKPerBlock,
                                                            NumWarps,
                                                            KPack,
                                                            Alignment>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeUSmemLoadTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_u;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentU<Problem>();
constexpr index_t KPack = GetSmemKPackU<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchU;
return MakeSmemLoadTileDescriptor_SimpleMxK_Async<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment,
KPack,
NumPrefetch>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeDSmemLoadTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_d;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_y;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignmentD<Problem>();
constexpr index_t KPack = GetSmemKPackD<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchD;
return MakeSmemLoadTileDescriptor_SimpleMxK_Async<kNPerBlock,
kKPerBlock,
NumWarps,
Alignment,
KPack,
NumPrefetch>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetGemm0()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::ADataType,
typename Problem::GDataType, // UDataType is the same
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::FusedMoeTileShape::kM_a,
Problem::FusedMoeTileShape::kN_g * 2,
Problem::FusedMoeTileShape::kK_a>>;
constexpr auto warp_gemm = []() {
return WarpGemmMfmaDispatcher<
typename Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
Problem::FusedMoeTileShape::Gemm0WarpTile::at(number<0>{}),
Problem::FusedMoeTileShape::Gemm0WarpTile::at(number<1>{}),
Problem::FusedMoeTileShape::Gemm0WarpTile::at(number<2>{}),
true /*TransposeC*/>{};
}();
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<
typename Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
typename Problem::FusedMoeTileShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetGemm1()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::YDataType,
typename Problem::DDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::FusedMoeTileShape::kM_a,
Problem::FusedMoeTileShape::kN_d,
Problem::FusedMoeTileShape::kK_y>>;
constexpr auto warp_gemm = []() {
return WarpGemmMfmaDispatcher<
typename Problem::YDataType,
typename Problem::DDataType,
typename Problem::AccDataType,
Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<0>{}),
Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<1>{}),
Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<2>{}),
true /*TransposeC*/>{};
}();
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<
typename Problem::YDataType,
typename Problem::DDataType,
typename Problem::AccDataType,
typename Problem::FusedMoeTileShape::Gemm1BlockWarps,
decltype(warp_gemm)>;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename ADataType_,
typename GDataType_,
typename UDataType_,
typename DDataType_,
typename ODataType_,
typename AccDataType_,
typename ScaleDataType_,
typename GateActivation_, // = ck_tile::element_wise::Silu,
typename FusedMoeTileShape_,
typename Traits_>
struct FusedMoePipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using YDataType = ADataType;
using GDataType = remove_cvref_t<GDataType_>;
using UDataType = remove_cvref_t<UDataType_>;
using DDataType = remove_cvref_t<DDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ScaleDataType = remove_cvref_t<ScaleDataType_>;
using FusedMoeTileShape = remove_cvref_t<FusedMoeTileShape_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = FusedMoeTileShape::NumWarps * get_warp_size();
// attributes from traits
// static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
// static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
// static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
// static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
// static constexpr auto BiasEnum = Traits::BiasEnum;
// static constexpr bool kStoreLSE = Traits::kStoreLSE;
// static constexpr bool kHasDropout = Traits::kHasDropout;
// static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
    using GateActivation = remove_cvref_t<GateActivation_>;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
tensors:
1. act (A): input feature map
2. gate (G): B matrix for first gemm, output will do activation(Silu)
3. up (U): B matrix for first gemm
4. down (D): B matrix for second gemm
N_d
/ \
+----------+ |
| Down | |
x----------x |
hidden hidden K_d | | |
N_g N_u x----------x |
| +------x-----x------+------x-----x------+ | | |
dim | | Gate | | | Up | | | | | |
contiguous | | | | | | | | | | |
| | | | | | | | | | |
v +------x-----x------+------x-----x------+ +----------+ V
K_a | | | | | contiguous
/ \ v v v v |
+---------+ +------x-----x------+------x-----x------+ |
M_a | A | | | | | | | | |
+---------+ +------x-----x------+------x-----x------+ |
--------> | | |
contiguous | V V
| x-----x +----------+
+-------------> | Y | ---------> | Out(O) |
SILU x-----x +----------+
K_y = N_g = N_u dim
*/
template <typename BlockTile_, // sequence<M_a, N_g, N_u, K_a, N_d>
typename Gemm0BlockWarps_,
typename Gemm0WarpTile_,
typename Gemm1BlockWarps_,
typename Gemm1WarpTile_,
bool IsDLayoutRowMajor_>
struct FusedMoeTileShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
using Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>;
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
static constexpr index_t NumWarps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}));
static constexpr index_t kM_a = BlockTile::at(number<0>{});
static constexpr index_t kN_g = BlockTile::at(number<1>{});
static constexpr index_t kN_u = BlockTile::at(number<2>{});
static constexpr index_t kK_a = BlockTile::at(number<3>{});
static constexpr index_t kN_d = BlockTile::at(number<4>{});
static_assert(kN_g == kN_u);
static constexpr index_t kK_y = kN_g;
static constexpr index_t kM_0 = kM_a;
    static constexpr index_t kN_0 = kN_g; // note: N is doubled (gate and up are fused)
static constexpr index_t kK_0 = kK_a;
static constexpr index_t kM_1 = kM_0;
static constexpr index_t kN_1 = kN_d;
static constexpr index_t kK_1 = kN_g;
    // d, rowmajor: hidden*dim; colmajor: dim*hidden (vLLM uses this layout)
static constexpr bool IsDLayoutRowMajor = IsDLayoutRowMajor_;
using DLayout = std::conditional_t<IsDLayoutRowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
};
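// Example instantiation (hypothetical tile sizes, for illustration only):
//   using Shape = FusedMoeTileShape<sequence<32, 128, 128, 64, 128>, // M_a, N_g, N_u, K_a, N_d
//                                   sequence<1, 4, 1>, sequence<32, 32, 8>,
//                                   sequence<1, 4, 1>, sequence<32, 32, 8>,
//                                   /*IsDLayoutRowMajor*/ true>;
//   => NumWarps = 4, kK_y = kN_g = 128, gemm-0 is 32x(128x2)x64, gemm-1 is 32x128x128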
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <bool GateUpPreShuffled_ = false,
bool DownPreShuffled_ = false,
index_t NumPrefetchA_ = 2,
index_t NumPrefetchG_ = 2,
index_t NumPrefetchU_ = 2,
index_t NumPrefetchD_ = 2,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct FusedMoeTraits
{
static constexpr bool GateUpPreShuffled = GateUpPreShuffled_;
static constexpr bool DownPreShuffled = DownPreShuffled_;
static constexpr index_t NumPrefetchA = NumPrefetchA_;
static constexpr index_t NumPrefetchG = NumPrefetchG_;
static constexpr index_t NumPrefetchU = NumPrefetchU_;
static constexpr index_t NumPrefetchD = NumPrefetchD_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
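// Example (hypothetical, for illustration only): double-buffer all four operands
// and let the occupancy be derived automatically:
//   using Traits = FusedMoeTraits<false, false, 2, 2, 2, 2, -1>;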
} // namespace ck_tile
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(tile_example_permute EXCLUDE_FROM_ALL permute.cpp)
if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL)
# set(PERMUTE_USE_ALTERNATIVE_IMPL false)
set(PERMUTE_USE_ALTERNATIVE_IMPL true)
endif()
if(PERMUTE_USE_ALTERNATIVE_IMPL)
target_compile_options(tile_example_permute PRIVATE -DPERMUTE_USE_ALTERNATIVE_IMPL)
target_sources(tile_example_permute PRIVATE alternative_impl/matrix_core_swizzle.cpp)
endif()
# target_compile_options(tile_example_permute PRIVATE -v --save-temps -Wno-gnu-line-marker)
......@@ -36,3 +36,11 @@ or you can try the smoke_test
# in the root of ck_tile, after you build this example
sh example/ck_tile/06_permute/script/smoke_test.sh
```
### alternative implementation
we have an alternative implementation under the `alternative_impl/` folder that can swizzle the tensor into a layout friendlier for matrix-core data loading. It is enabled when dealing with a `rank-7` tensor whose permute pattern is either `0,1,4,2,5,3,6` or `0,1,2,4,5,3,6`. There are additional shape limitations in this implementation; check the source code of `permute.cpp` for details.
```
# example
./build/bin/tile_example_permute -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 # b_n0_k0_n1_k1_n2_k2
./build/bin/tile_example_permute -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 # b_n0_n1_k0_k1_n2_k2
```
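For example, `-shape=3,6,4,32,16,2,8` with the `32x32x8` instruction is interpreted as `batch=3`, `n = 6*4*32 = 768` and `k = 16*2*8 = 256`; the inner factors (`4,32` on the n side, `2,8` on the k side) must match the wave count and mfma lane layout of the chosen instruction, which is where the shape limitations come from.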
#include "matrix_core_swizzle.hpp"
#include "matrix_core_swizzle_kernel.hpp"
float matrix_core_swizzle(matrix_core_swizzle_traits t,
matrix_core_swizzle_args a,
const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
if(t.inst.compare("32x32x8") == 0)
{
constexpr int BLOCK_SIZE = 256;
constexpr int NPerBlock = 256;
constexpr int KPerBlock = 128;
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16;
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
}
else if(t.inst.compare("16x16x16") == 0)
{
constexpr int BLOCK_SIZE = 256;
constexpr int NPerBlock = 256;
constexpr int KPerBlock = 128;
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16;
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
}
}
return -1;
}
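// Usage sketch (hypothetical values, for illustration only):
//   matrix_core_swizzle_traits t{"fp16", "32x32x8", "0,1,4,2,5,3,6"};
//   matrix_core_swizzle_args a{p_src, p_dst, /*batch*/ 3, /*n*/ 768, /*k*/ 256};
//   float ave_ms = matrix_core_swizzle(t, a, stream_cfg);
// a negative return value (-1) means the (data_type, inst, permute) combination
// is not supported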
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "matrix_core_swizzle_kernel.hpp"
#include <string>
struct matrix_core_swizzle_traits
{
std::string data_type; // fp16 only
std::string inst; // 32x32x8, 16x16x16
std::string permute; //
};
using matrix_core_swizzle_args = matrix_core_swizzle_host_args;
// host API
float matrix_core_swizzle(matrix_core_swizzle_traits,
matrix_core_swizzle_args,
const ck_tile::stream_config&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
enum class matrix_core_inst_enum
{
MFMA_32x32x8_F16 = 0,
MFMA_16x16x16_F16 = 1,
};
namespace detail {
template <matrix_core_inst_enum>
struct to_warp_gemm;
template <>
struct to_warp_gemm<matrix_core_inst_enum::MFMA_32x32x8_F16>
{
using type = ck_tile::WarpGemmMfmaF16F16F32M32N32K8;
};
template <>
struct to_warp_gemm<matrix_core_inst_enum::MFMA_16x16x16_F16>
{
using type = ck_tile::WarpGemmMfmaF16F16F32M16N16K16;
};
} // namespace detail
template <matrix_core_inst_enum Inst>
using to_warp_gemm_t = typename detail::to_warp_gemm<Inst>::type;
enum class matrix_core_permute_style
{
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
};
// assume this is B matrix, originally we have batch*n*k
// now batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
// assume using 32x32x8-f16, 4 waves and extend the KPerLane to 8xfp16(dwordx4)
//
// 4(waves) 32(mfma_m lane)
// | |
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 -> 8(thread loading)
// nr kr |
// nr 4 32 kr 2 8 2(klane)
//
// permute: 0,1,4,2,5,3,6
// or
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*n1*k0*k1*n2*k2 -> 8(thread loading)
// permute: 0,1,2,4,5,3,6
//
// this kernel only deals with fp16/bf16 data (16-bit), and uses a 2d block size to do the swizzling
// for simplicity, only consider n/k that are multiples of the block tile size
// independent host args with no template
struct matrix_core_swizzle_host_args
{
const void* p_src;
void* p_dst;
int32_t batch;
int32_t n;
int32_t k;
};
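// NOTE (illustrative, mapping taken from how permute.cpp fills these fields):
// for the rank-7 swizzle path the caller flattens the tensor as
//   batch = shape[0], n = shape[1]*shape[2]*shape[3], k = shape[4]*shape[5]*shape[6]
// e.g. shape 3,6,4,32,16,2,8 gives batch=3, n=768, k=256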
// NOTE: this kernel could follow the style of generic permute kernel
// but here we pass in fixed layout as template arg and generate different kernel instance
// purposely
template <int BLOCK_SIZE_ = 256,
int NPerBlock_ = 256,
int KPerBlock_ = 128,
matrix_core_permute_style pstyle_ =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2,
matrix_core_inst_enum Inst_ = matrix_core_inst_enum::MFMA_32x32x8_F16>
struct matrix_core_swizzle_kernel
{
using karg = matrix_core_swizzle_host_args;
using harg = matrix_core_swizzle_host_args;
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
static constexpr int NPerBlock = NPerBlock_;
static constexpr int KPerBlock = KPerBlock_;
static constexpr matrix_core_permute_style pstyle = pstyle_;
static constexpr matrix_core_inst_enum Inst = Inst_;
static constexpr ck_tile::index_t Alignment = 8;
karg a;
dim3 grids;
using WarpGemm = to_warp_gemm_t<Inst>;
__host__ matrix_core_swizzle_kernel(harg h)
{
a = h;
ck_tile::index_t ns = (h.n + NPerBlock - 1) / NPerBlock;
ck_tile::index_t ks = (h.k + KPerBlock - 1) / KPerBlock;
grids = dim3(ks, ns, h.batch);
}
__host__ bool is_applicable(harg h) { return h.n % NPerBlock == 0 && h.k % KPerBlock == 0; }
__host__ void operator()(const ck_tile::stream_config& s) const
{
ck_tile::kentry<BLOCK_SIZE, 1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
}
struct kernel
{
__device__ static constexpr auto get_src_dist()
{
using namespace ck_tile;
constexpr index_t K2 = Alignment;
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
static_assert(NPerBlock % (N1 * N2) == 0);
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t K0 = KPerBlock / (K1 * K2);
constexpr index_t N0 = NPerBlock / (N1 * N2);
// clang-format off
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,// 0
// 1 2 3 4 5 6
tuple<sequence<N0>, sequence<N1>, sequence<N2>, sequence<K0>, sequence<K1>, sequence<K2>>,
// N1 K1 N2
tuple<sequence<2>, sequence<5, 3>>,
tuple<sequence<0>, sequence<0, 0>>,
// N0 K0 K2
sequence<1, 4, 6>,
sequence<0, 0, 0>>{});
// clang-format on
}
__device__ static constexpr auto get_dst_dist()
{
using namespace ck_tile;
constexpr index_t K2 = Alignment;
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
static_assert(NPerBlock % (N1 * N2) == 0);
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t K0 = KPerBlock / (K1 * K2);
constexpr index_t N0 = NPerBlock / (N1 * N2);
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
{
// clang-format off
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,// 0
// 1 2 3 4 5 6
tuple<sequence<N0>, sequence<K0>, sequence<N1>, sequence<K1>, sequence<N2>, sequence<K2>>,
// N1 K1 N2
tuple<sequence<3>, sequence<4, 5>>,
tuple<sequence<0>, sequence<0, 0>>,
// N0 K0 K2
sequence<1, 2, 6>,
sequence<0, 0, 0>>{});
// clang-format on
}
else
{
// clang-format off
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,// 0
// 1 2 3 4 5 6
tuple<sequence<N0>, sequence<N1>, sequence<K0>, sequence<K1>, sequence<N2>, sequence<K2>>,
// N1 K1 N2
tuple<sequence<2>, sequence<4, 5>>,
tuple<sequence<0>, sequence<0, 0>>,
// N0 K0 K2
sequence<1, 3, 6>,
sequence<0, 0, 0>>{});
// clang-format on
}
}
__device__ void operator()(karg a_)
{
using namespace ck_tile;
index_t i_k = blockIdx.x;
index_t i_n = blockIdx.y;
index_t i_b = blockIdx.z;
constexpr index_t k2 = Alignment;
constexpr index_t n2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t k1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t n1 = BLOCK_SIZE / get_warp_size();
const index_t k0 = a_.k / (k1 * k2);
const index_t n0 = a_.n / (n1 * n2);
constexpr index_t k2_tile = Alignment;
constexpr index_t n2_tile = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t k1_tile = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t n1_tile = BLOCK_SIZE / get_warp_size();
constexpr index_t k0_tile = KPerBlock / (k1_tile * k2_tile);
constexpr index_t n0_tile = NPerBlock / (n1_tile * n2_tile);
const fp16_t* p_src = reinterpret_cast<const fp16_t*>(a_.p_src) + i_b * a_.k * a_.n;
fp16_t* p_dst = reinterpret_cast<fp16_t*>(a_.p_dst) + i_b * a_.k * a_.n;
const auto src_view = [&]() {
const auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_src,
make_tuple(n0, n1, n2, k0, k1, k2),
number<Alignment>{}); // control vector load
return tmp;
}();
const auto src_window = make_tile_window(src_view,
make_tuple(number<n0_tile>{},
number<n1_tile>{},
number<n2_tile>{},
number<k0_tile>{},
number<k1_tile>{},
number<k2_tile>{}),
{i_n * n0_tile, 0, 0, i_k * k0_tile, 0, 0},
get_src_dist());
auto dst_view = [&]() {
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
{
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst,
make_tuple(n0, k0, n1, k1, n2, k2),
number<Alignment>{}); // control vector load
return tmp;
}
else
{
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst,
make_tuple(n0, n1, k0, k1, n2, k2),
number<Alignment>{}); // control vector load
return tmp;
}
}();
auto dst_window = [&]() {
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
{
return make_tile_window(dst_view,
make_tuple(number<n0_tile>{},
number<k0_tile>{},
number<n1_tile>{},
number<k1_tile>{},
number<n2_tile>{},
number<k2_tile>{}),
{i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0},
get_dst_dist());
}
else
{
return make_tile_window(dst_view,
make_tuple(number<n0_tile>{},
number<n1_tile>{},
number<k0_tile>{},
number<k1_tile>{},
number<n2_tile>{},
number<k2_tile>{}),
{i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0},
get_dst_dist());
}
}();
// actual load store
auto src_tile = load_tile(src_window);
// now we only swap the distribution from src to dst, no extra movement occurs
auto dst_tile = make_static_distributed_tensor<fp16_t>(get_dst_dist());
dst_tile.get_thread_buffer() = src_tile.get_thread_buffer();
// final store
store_tile(dst_window, dst_tile);
}
};
};
......@@ -14,6 +14,10 @@
#include <utility>
#include <vector>
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
#include "alternative_impl/matrix_core_swizzle.hpp"
#endif
namespace detail {
template <int bytes>
struct to_integer_type;
......@@ -191,7 +195,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
auto x_shape = decode_vec(arg_parser.get_str("shape"));
auto shape = decode_vec(arg_parser.get_str("shape"));
auto perm = decode_vec(arg_parser.get_str("perm"));
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
......@@ -206,7 +210,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
return false;
}
ck_tile::HostTensor<DataType> x(x_shape);
ck_tile::HostTensor<DataType> x(shape);
ck_tile::FillUniformDistributionIntegerValue<DataType>{-15, 15, seed}(x);
std::vector<ck_tile::index_t> y_shape = [&]() {
......@@ -217,7 +221,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// std::cout << " i:" << i << ", perm:" << perm[i] << ", rak:" <<
// static_cast<int>(rank)
// << std::endl;
tmp[i] = x_shape[perm[i]];
tmp[i] = shape[perm[i]];
}
// std::cout << "@@@" << tmp << std::endl;
return tmp;
......@@ -230,26 +234,78 @@ bool run(const ck_tile::ArgParser& arg_parser)
x_buf.ToDevice(x.data());
permute_args args;
args.p_src = x_buf.GetDeviceBuffer();
args.p_dst = y_buf.GetDeviceBuffer();
args.rank = rank;
std::copy(x_shape.begin(), x_shape.end(), args.shape);
std::copy(perm.begin(), perm.end(), args.perm);
permute_traits trait;
trait.data_type = data_type;
std::cout << "[" << data_type << "] shape:" << x_shape << "->" << y_shape
<< ", permute:" << perm << std::flush;
std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm
<< std::flush;
ck_tile::stream_config stream_config{nullptr,
true,
/* log_level = */ (kname ? 1 : 0),
stream_warmup,
stream_repeat};
float ave_time = 0.f;
auto run_permute = [&]() {
permute_traits t;
t.data_type = data_type;
permute_args a;
a.p_src = x_buf.GetDeviceBuffer();
a.p_dst = y_buf.GetDeviceBuffer();
a.rank = rank;
std::copy(shape.begin(), shape.end(), a.shape);
std::copy(perm.begin(), perm.end(), a.perm);
return permute(t, a, stream_config);
};
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
if(rank == 7 && (arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6")))
{
matrix_core_swizzle_traits t;
t.data_type = data_type;
t.permute = arg_parser.get_str("perm");
matrix_core_swizzle_args a;
a.p_src = x_buf.GetDeviceBuffer();
a.p_dst = y_buf.GetDeviceBuffer();
a.batch = shape[0];
a.n = shape[1] * shape[2] * shape[3];
a.k = shape[4] * shape[5] * shape[6];
if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 && shape[4] % 8 == 0 &&
shape[1] % 2 == 0)
{
// 32x32x8 inst
// perm=0,1,4,2,5,3,6
// y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8)
// shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8)
t.inst = "32x32x8";
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
float ave_time = permute(trait, args, stream_config);
ave_time = matrix_core_swizzle(t, a, stream_config);
}
else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 &&
shape[4] % 4 == 0 && shape[1] % 4 == 0)
{
// 16x16x16 inst
// perm=0,1,4,2,5,3,6
// y_shape=*,4x,4x,4,4,16,8
// shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8)
t.inst = "16x16x16";
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
ave_time = matrix_core_swizzle(t, a, stream_config);
}
else
{
ave_time = run_permute();
}
}
else
#endif
{
ave_time = run_permute();
}
std::cout << ", time:" << ave_time << "ms" << std::flush;
bool pass = true;
......
......@@ -9,6 +9,14 @@ if [ $# -ge 1 ] ; then
set -x
fi
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
echo "------------------------------------------------------------------"
for prec in "fp8" "fp16" "fp32" ; do
$EXE -prec=$prec -shape=3,8 -perm=1,0 $COMMON_ARGS
......
......@@ -4,7 +4,9 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
namespace ck_tile {
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
namespace ck_tile {
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
namespace ck_tile {
......