Commit cdfceb0a authored by Astha Rai's avatar Astha Rai
Browse files

Merge branch 'codegen_hiprtc' of github.com:ROCm/composable_kernel into codegen_hiprtc

parents b46349df 3b9a77df
...@@ -47,10 +47,16 @@ struct FmhaFwdSplitKVKernel ...@@ -47,10 +47,16 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static constexpr bool kMergeNumHeadGroupsSeqLenQ =
FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>; using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kHasMask = FmhaMask::IsMasking;
static_assert(!kMergeNumHeadGroupsSeqLenQ ||
(kMergeNumHeadGroupsSeqLenQ && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
!kHasMask));
// clang-format off // clang-format off
template <typename T> struct t2s; template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; }; template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
...@@ -476,15 +482,20 @@ struct FmhaFwdSplitKVKernel ...@@ -476,15 +482,20 @@ struct FmhaFwdSplitKVKernel
} }
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead, ck_tile::index_t nhead_q,
ck_tile::index_t nhead_kv,
ck_tile::index_t max_seqlen_q, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_splits) ck_tile::index_t num_splits)
{ {
ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q;
ck_tile::index_t max_seqlen_q_ =
max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1);
// TODO: this may need tuning // TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits, ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
nhead, nhead_,
batch_size); batch_size);
} }
...@@ -562,7 +573,7 @@ struct FmhaFwdSplitKVKernel ...@@ -562,7 +573,7 @@ struct FmhaFwdSplitKVKernel
// # of required blocks is different in each groups, terminate unnecessary blocks // # of required blocks is different in each groups, terminate unnecessary blocks
// earlier // earlier
if(kargs.seqlen_q <= i_m0) if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0)
{ {
return; return;
} }
...@@ -617,30 +628,60 @@ struct FmhaFwdSplitKVKernel ...@@ -617,30 +628,60 @@ struct FmhaFwdSplitKVKernel
} }
// for simplicity, batch stride we just modify the pointer // for simplicity, batch stride we just modify the pointer
const index_t i_nhead_k =
(kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) + const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q + static_cast<long_index_t>(i_nhead) *
(kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
kargs.nhead_stride_q +
batch_offset_q; batch_offset_q;
const KDataType* k_ptr = const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
reinterpret_cast<const KDataType*>(kargs.k_ptr) + static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + batch_offset_k;
batch_offset_k; const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
const VDataType* v_ptr = static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
reinterpret_cast<const VDataType*>(kargs.v_ptr) + batch_offset_v;
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) + ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc + static_cast<long_index_t>(i_nhead) *
(kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
kargs.nhead_stride_o_acc +
batch_offset_o_acc + i_split * kargs.split_stride_o_acc; batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
// Q/K/V DRAM and DRAM window // Q/K/V DRAM and DRAM window
const auto q_dram = [&]() { const auto q_dram = [&] {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto q_dram_naive = [&] {
q_ptr, if constexpr(kMergeNumHeadGroupsSeqLenQ)
make_tuple(kargs.seqlen_q, kargs.hdim_q), {
make_tuple(kargs.stride_q, 1), // reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
number<FmhaPipeline::kAlignmentQ>{}, // hdim_q)
number<1>{}); const auto view = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
return transform_tensor_view(
view,
make_tuple(
make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
make_pass_through_transform(kargs.hdim_q)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
}
}();
if constexpr(FmhaPipeline::kQLoadOnce) if constexpr(FmhaPipeline::kQLoadOnce)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -729,7 +770,7 @@ struct FmhaFwdSplitKVKernel ...@@ -729,7 +770,7 @@ struct FmhaFwdSplitKVKernel
} }
}(); }();
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { auto k_page_block_navigator = [&, i_batch_ = i_batch]() {
if constexpr(kIsPagedKV) if constexpr(kIsPagedKV)
{ {
const auto* block_indices = const auto* block_indices =
...@@ -739,8 +780,7 @@ struct FmhaFwdSplitKVKernel ...@@ -739,8 +780,7 @@ struct FmhaFwdSplitKVKernel
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset = const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) * static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k;
kargs.nhead_stride_k;
return make_page_block_navigator<const KDataType, 0>( return make_page_block_navigator<const KDataType, 0>(
kargs.k_ptr, kargs.k_ptr,
...@@ -760,7 +800,7 @@ struct FmhaFwdSplitKVKernel ...@@ -760,7 +800,7 @@ struct FmhaFwdSplitKVKernel
} }
}(); }();
auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { auto v_page_block_navigator = [&, i_batch_ = i_batch]() {
if constexpr(kIsPagedKV) if constexpr(kIsPagedKV)
{ {
const auto* block_indices = const auto* block_indices =
...@@ -770,8 +810,7 @@ struct FmhaFwdSplitKVKernel ...@@ -770,8 +810,7 @@ struct FmhaFwdSplitKVKernel
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset = const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) * static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v;
kargs.nhead_stride_v;
return make_page_block_navigator<const VDataType, 1>( return make_page_block_navigator<const VDataType, 1>(
kargs.v_ptr, kargs.v_ptr,
...@@ -842,19 +881,40 @@ struct FmhaFwdSplitKVKernel ...@@ -842,19 +881,40 @@ struct FmhaFwdSplitKVKernel
// lse acc // lse acc
auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() { auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() {
constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{}); constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
LSEDataType* lse_acc_ptr = LSEDataType* lse_acc_ptr = reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) + static_cast<long_index_t>(i_nhead_) *
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse_acc + (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc; kargs.nhead_stride_lse_acc +
batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;
const auto lse_acc_dram = [&]() {
const auto lse_acc_dram_naive = const auto lse_acc_dram = [&] {
make_naive_tensor_view<address_space_enum::global>(lse_acc_ptr, const auto lse_acc_dram_naive = [&] {
make_tuple(kargs.seqlen_q), if constexpr(kMergeNumHeadGroupsSeqLenQ)
make_tuple(1), {
number<1>{}, // reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q)
number<1>{}); const auto view = make_naive_tensor_view<address_space_enum::global>(
lse_acc_ptr,
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q),
make_tuple(kargs.nhead_stride_lse_acc, 1),
number<1>{},
number<1>{});
return transform_tensor_view(view,
make_tuple(make_merge_transform(make_tuple(
kargs.nhead_ratio_qk, kargs.seqlen_q))),
make_tuple(sequence<0, 1>{}),
make_tuple(sequence<0>{}));
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
lse_acc_ptr,
make_tuple(kargs.seqlen_q),
make_tuple(1),
number<1>{},
number<1>{});
}
}();
return pad_tensor_view( return pad_tensor_view(
lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{}); lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{});
}(); }();
...@@ -953,13 +1013,37 @@ struct FmhaFwdSplitKVKernel ...@@ -953,13 +1013,37 @@ struct FmhaFwdSplitKVKernel
}(); }();
// Oacc DRAM and Oacc DRAM window // Oacc DRAM and Oacc DRAM window
auto o_acc_dram = [&]() { auto o_acc_dram = [&] {
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto o_acc_dram_naive = [&] {
o_acc_ptr, if constexpr(kMergeNumHeadGroupsSeqLenQ)
make_tuple(kargs.seqlen_q, kargs.hdim_v), {
make_tuple(kargs.stride_o_acc, 1), // reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
number<FmhaPipeline::kAlignmentOacc>{}, // hdim_v)
number<1>{}); const auto view = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
return transform_tensor_view(
view,
make_tuple(
make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
make_pass_through_transform(kargs.hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
}
}();
return pad_tensor_view( return pad_tensor_view(
o_acc_dram_naive, o_acc_dram_naive,
......
...@@ -94,16 +94,17 @@ struct BlockFmhaFwdSplitKVPipelineProblem ...@@ -94,16 +94,17 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits // attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = Traits::kIsPagedKV; static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
// extract tile size attributes to remove dependency on traits // extract tile size attributes to remove dependency on traits
......
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
namespace ck_tile { namespace ck_tile {
/// NOTICE: we no-longer use this pipeline.
// This pipeline is qkv all located in LDS // This pipeline is qkv all located in LDS
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy> template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
struct [[deprecated]] BlockFmhaPipelineQSKSVS struct BlockFmhaPipelineQSKSVS
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
...@@ -51,6 +51,24 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -51,6 +51,24 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kBlockPerCu = []() { static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1) if constexpr(Problem::kBlockPerCu != -1)
...@@ -81,6 +99,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -81,6 +99,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
static constexpr const char* name = "qs"; static constexpr const char* name = "qs";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
...@@ -95,6 +115,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -95,6 +115,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename QElementFunction, typename QElementFunction,
typename KElementFunction, typename KElementFunction,
...@@ -114,6 +135,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -114,6 +135,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
const VElementFunction& v_element_func, const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func, const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& /* unused_randval_dram_block_window_tmp */,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func, const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func, const SAccElementFunction& s_acc_element_func,
...@@ -122,7 +144,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -122,7 +144,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr,
DropoutType& /* unused_dropout */) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -222,11 +245,11 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -222,11 +245,11 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
{seqlen_k_start, 0}); {seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window( auto bias_dram_window =
bias_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(), bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto v_dram_window = auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
...@@ -583,6 +606,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -583,6 +606,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
...@@ -590,11 +614,13 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -590,11 +614,13 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr,
DropoutType& dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
...@@ -604,6 +630,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -604,6 +630,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
identity{}, identity{},
bias_dram_block_window_tmp, bias_dram_block_window_tmp,
identity{}, identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp, lse_dram_block_window_tmp,
identity{}, identity{},
identity{}, identity{},
...@@ -612,7 +639,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -612,7 +639,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
mask, mask,
position_encoding, position_encoding,
scale_s, scale_s,
smem_ptr); smem_ptr,
dropout);
} }
}; };
......
...@@ -125,9 +125,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -125,9 +125,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
} }
}; };
/// NOTICE: we no-longer use this policy.
template <> template <>
struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
{ {
static constexpr bool QLoadOnce = false; static constexpr bool QLoadOnce = false;
......
...@@ -43,7 +43,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */, ...@@ -43,7 +43,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kDoFp8StaticQuant_, bool kDoFp8StaticQuant_,
bool kIsPagedKV_, bool kIsPagedKV_,
bool kHasUnevenSplits_, bool kHasUnevenSplits_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */> bool kMergeNumHeadGroupsSeqLenQ_ = false,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdSplitKVTraits struct TileFmhaFwdSplitKVTraits
{ {
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
...@@ -56,8 +57,9 @@ struct TileFmhaFwdSplitKVTraits ...@@ -56,8 +57,9 @@ struct TileFmhaFwdSplitKVTraits
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kIsPagedKV = kIsPagedKV_; static constexpr bool kIsPagedKV = kIsPagedKV_;
// determine if some split (length) is not divisible by tile size // determine if some split (length) is not divisible by tile size
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
template <bool kPadSeqLenQ_ /* padding for seqlen_q */, template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -15,6 +15,7 @@ struct Layernorm2dFwdHostArgs ...@@ -15,6 +15,7 @@ struct Layernorm2dFwdHostArgs
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -43,6 +44,7 @@ struct Layernorm2dFwd ...@@ -43,6 +44,7 @@ struct Layernorm2dFwd
using Problem = typename Pipeline::Problem; using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -67,6 +69,7 @@ struct Layernorm2dFwd ...@@ -67,6 +69,7 @@ struct Layernorm2dFwd
static constexpr bool kPadM = false; // always no need to pad along M static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::Traits::kTwoPass; static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...@@ -82,6 +85,7 @@ struct Layernorm2dFwd ...@@ -82,6 +85,7 @@ struct Layernorm2dFwd
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -108,6 +112,7 @@ struct Layernorm2dFwd ...@@ -108,6 +112,7 @@ struct Layernorm2dFwd
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_x_residual, hargs.p_x_residual,
hargs.p_x_scale, hargs.p_x_scale,
hargs.p_x_bias,
hargs.p_gamma, hargs.p_gamma,
hargs.p_beta, hargs.p_beta,
hargs.p_y, hargs.p_y,
...@@ -152,6 +157,7 @@ struct Layernorm2dFwd ...@@ -152,6 +157,7 @@ struct Layernorm2dFwd
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
auto surfix = [&] () { auto surfix = [&] () {
std::string n; std::string n;
if (kXbias != Layernorm2dXBiasEnum::NO_BIAS) n += _SS_("_") + Layernorm2dXBiasEnumName<kXbias>::name;
if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name; if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name; if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name;
if (kPadN) n += "_pn"; if (kPadN) n += "_pn";
...@@ -228,6 +234,27 @@ struct Layernorm2dFwd ...@@ -228,6 +234,27 @@ struct Layernorm2dFwd
} }
}(); }();
const auto x_bias_window = [&]() {
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XBiasDataType*>(kargs.p_x_bias),
make_tuple(kargs.n),
make_tuple(1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_N>{}));
}
}();
const auto gamma_window = [&]() { const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma), static_cast<const GammaDataType*>(kargs.p_gamma),
...@@ -371,6 +398,7 @@ struct Layernorm2dFwd ...@@ -371,6 +398,7 @@ struct Layernorm2dFwd
Pipeline{}(x_window, Pipeline{}(x_window,
x_residual_window, x_residual_window,
x_bias_window,
gamma_window, gamma_window,
beta_window, beta_window,
y_window, y_window,
......
...@@ -18,6 +18,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -18,6 +18,7 @@ struct Layernorm2dFwdPipelineOnePass
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -38,6 +39,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -38,6 +39,7 @@ struct Layernorm2dFwdPipelineOnePass
static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford; static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...@@ -55,6 +57,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -55,6 +57,7 @@ struct Layernorm2dFwdPipelineOnePass
template <typename XWindow, template <typename XWindow,
typename XResidualWindow, typename XResidualWindow,
typename XBiasWindow,
typename GammaWindow, typename GammaWindow,
typename BetaWindow, typename BetaWindow,
typename YWindow, typename YWindow,
...@@ -66,6 +69,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -66,6 +69,7 @@ struct Layernorm2dFwdPipelineOnePass
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_, const XResidualWindow& x_residual_window_,
const XBiasWindow& x_bias_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
const BetaWindow& beta_window_, const BetaWindow& beta_window_,
YWindow& y_window_, YWindow& y_window_,
...@@ -81,6 +85,8 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -81,6 +85,8 @@ struct Layernorm2dFwdPipelineOnePass
{ {
const auto x_window = const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto x_bias_window = make_tile_window(
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window( const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto beta_window = make_tile_window( const auto beta_window = make_tile_window(
...@@ -90,8 +96,9 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -90,8 +96,9 @@ struct Layernorm2dFwdPipelineOnePass
auto y_residual_window = make_tile_window( auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>()); y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
const auto x_bias = load_tile(x_bias_window);
int cur_count = 0; int cur_count = 0;
int max_count = int max_count =
...@@ -112,6 +119,15 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -112,6 +119,15 @@ struct Layernorm2dFwdPipelineOnePass
auto acc = cast_tile<ComputeDataType>(x); auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{ {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
namespace ck_tile { namespace ck_tile {
template <typename XDataType_, template <typename XDataType_,
typename XBiasDataType_,
typename GammaDataType_, typename GammaDataType_,
typename BetaDataType_, typename BetaDataType_,
typename ComputeDataType_, typename ComputeDataType_,
...@@ -21,6 +22,7 @@ template <typename XDataType_, ...@@ -21,6 +22,7 @@ template <typename XDataType_,
struct Layernorm2dFwdPipelineProblem struct Layernorm2dFwdPipelineProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using XBiasDataType = remove_cvref_t<XBiasDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>; using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>; using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
......
...@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -37,6 +38,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -37,6 +38,7 @@ struct Layernorm2dFwdPipelineTwoPass
static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford; static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...@@ -54,6 +56,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -54,6 +56,7 @@ struct Layernorm2dFwdPipelineTwoPass
template <typename XWindow, template <typename XWindow,
typename XResidualWindow, typename XResidualWindow,
typename XBiasWindow,
typename GammaWindow, typename GammaWindow,
typename BetaWindow, typename BetaWindow,
typename YWindow, typename YWindow,
...@@ -65,6 +68,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -65,6 +68,7 @@ struct Layernorm2dFwdPipelineTwoPass
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_, const XResidualWindow& x_residual_window_,
const XBiasWindow& x_bias_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
const BetaWindow& beta_window_, const BetaWindow& beta_window_,
YWindow& y_window, YWindow& y_window,
...@@ -81,6 +85,8 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -81,6 +85,8 @@ struct Layernorm2dFwdPipelineTwoPass
static_assert(kWelford == true, "2 pass only supports welford merge"); static_assert(kWelford == true, "2 pass only supports welford merge");
auto x_window = auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto x_bias_window = make_tile_window(
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window( auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto beta_window = make_tile_window( auto beta_window = make_tile_window(
...@@ -115,13 +121,24 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -115,13 +121,24 @@ struct Layernorm2dFwdPipelineTwoPass
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
const auto x_bias = load_tile(x_bias_window);
move_tile_window(x_window, {0, Block_N}); move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N}); move_tile_window(x_residual_window, {0, Block_N});
move_tile_window(x_bias_window, {Block_N});
auto acc = cast_tile<ComputeDataType>(x); auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{ {
...@@ -167,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -167,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(x_bias_window, {-Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(beta_window, {stride_to_right_most_window}); move_tile_window(beta_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window});
...@@ -174,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -174,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass
// layernorm computation // layernorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
auto acc = cast_tile<ComputeDataType>(x); const auto x_bias = load_tile(x_bias_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
...@@ -209,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -209,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(x_bias_window, {-Block_N});
move_tile_window(gamma_window, {-Block_N}); move_tile_window(gamma_window, {-Block_N});
move_tile_window(beta_window, {-Block_N}); move_tile_window(beta_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N}); move_tile_window(y_window, {0, -Block_N});
......
...@@ -7,6 +7,19 @@ ...@@ -7,6 +7,19 @@
namespace ck_tile { namespace ck_tile {
enum class Layernorm2dXBiasEnum
{
NO_BIAS = 0,
// add bias before fused add
ADD_BIAS = 1,
};
// clang-format off
template<Layernorm2dXBiasEnum> struct Layernorm2dXBiasEnumName;
template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::NO_BIAS> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::ADD_BIAS> { static constexpr const char * name = "xbias"; };
// clang-format on
enum class Layernorm2dFusedAddEnum enum class Layernorm2dFusedAddEnum
{ {
NO_ADD = 0, NO_ADD = 0,
...@@ -42,6 +55,7 @@ template <bool kPadN_, ...@@ -42,6 +55,7 @@ template <bool kPadN_,
bool kFastFDiv_, bool kFastFDiv_,
bool kWelford_, bool kWelford_,
bool kTwoPass_, bool kTwoPass_,
Layernorm2dXBiasEnum kXbias_,
Layernorm2dFusedAddEnum kFusedAdd_, Layernorm2dFusedAddEnum kFusedAdd_,
Layernorm2dFusedQuantEnum kFusedQuant_> Layernorm2dFusedQuantEnum kFusedQuant_>
struct Layernorm2dFwdTraits struct Layernorm2dFwdTraits
...@@ -51,6 +65,7 @@ struct Layernorm2dFwdTraits ...@@ -51,6 +65,7 @@ struct Layernorm2dFwdTraits
static constexpr bool kFastFDiv = kFastFDiv_; static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_; static constexpr bool kWelford = kWelford_;
static constexpr bool kTwoPass = kTwoPass_; static constexpr bool kTwoPass = kTwoPass_;
static constexpr Layernorm2dXBiasEnum kXbias = kXbias_;
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_; static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_; static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment