Commit ea5be216 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents e2eb0418 25935b57
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
} // namespace ck_tile
...@@ -86,7 +86,7 @@ struct FmhaFwdKernel ...@@ -86,7 +86,7 @@ struct FmhaFwdKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
...@@ -387,7 +387,6 @@ struct FmhaFwdKernel ...@@ -387,7 +387,6 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t window_size_left, ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
...@@ -448,7 +447,6 @@ struct FmhaFwdKernel ...@@ -448,7 +447,6 @@ struct FmhaFwdKernel
{ {
kargs.lse_ptr = lse_ptr; kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse; kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
} }
if constexpr(kDoFp8StaticQuant) if constexpr(kDoFp8StaticQuant)
{ {
...@@ -524,7 +522,7 @@ struct FmhaFwdKernel ...@@ -524,7 +522,7 @@ struct FmhaFwdKernel
} }
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse; batch_offset_lse = query_start;
} }
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
......
...@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
_SS_(FmhaPipeline::name) + _SS_(FmhaPipeline::name) +
(pn.empty() ? "" : "_" + pn) + (pn.empty() ? "" : "_" + pn) +
(kStoreLSE ? "_lse" : "" ) + (kStoreLSE ? "_lse" : "" ) +
(kDoFp8StaticQuant ? "_squant" : "" ); (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
...@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_lse_acc;
...@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>> std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{ {
ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_lse_acc;
}; };
struct GroupModeKargs struct GroupModeKargs
...@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for lse {}, // placeholder for lse
{}, // placeholder for fp8_static_quant args {}, // placeholder for fp8_static_quant args
batch_stride_o}; batch_stride_o,
batch_stride_lse_acc};
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
...@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc, ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc) ck_tile::index_t split_stride_o_acc)
{ {
...@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
...@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
{ {
kargs.lse_ptr = lse_ptr; kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse; kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
} }
if constexpr(kDoFp8StaticQuant) if constexpr(kDoFp8StaticQuant)
{ {
...@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc = const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc; static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kStoreLSE) long_index_t batch_offset_lse_acc = 0;
{ long_index_t batch_offset_lse = 0;
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse; long_index_t batch_offset_o = 0;
}
if constexpr(kIsGroupMode) if constexpr(kIsGroupMode)
{ {
// get starting offset for each batch // get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o; batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_lse_acc = query_start;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
}
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
...@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel
} }
else else
{ {
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o; batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
} }
// for simplicity, batch stride we just modify the pointer // for simplicity, batch stride we just modify the pointer
......
...@@ -85,7 +85,7 @@ struct FmhaFwdSplitKVKernel ...@@ -85,7 +85,7 @@ struct FmhaFwdSplitKVKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
...@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel ...@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc; ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_lse_acc;
...@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel ...@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse_acc;
}; };
struct GroupModeKargs struct GroupModeKargs
...@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel ...@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
...@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel ...@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for dropout {}, // placeholder for dropout
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
batch_stride_v}; batch_stride_v,
batch_stride_lse_acc};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
...@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel ...@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc, ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc, ck_tile::index_t split_stride_o_acc,
...@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel ...@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
...@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel ...@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_v = 0; long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0; long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0; long_index_t batch_offset_randval = 0;
const long_index_t batch_offset_lse_acc = long_index_t batch_offset_lse_acc = 0;
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc = const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc; static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
...@@ -522,8 +519,9 @@ struct FmhaFwdSplitKVKernel ...@@ -522,8 +519,9 @@ struct FmhaFwdSplitKVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q; batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k; batch_offset_k = key_start * kargs.stride_k;
batch_offset_lse_acc = query_start;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
batch_offset_v = key_start * kargs.stride_v; batch_offset_v = key_start * kargs.stride_v;
...@@ -564,9 +562,10 @@ struct FmhaFwdSplitKVKernel ...@@ -564,9 +562,10 @@ struct FmhaFwdSplitKVKernel
} }
else else
{ {
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q; batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k; batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v; batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias; batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdConvertQGrad
{
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
static constexpr index_t kM0 = Problem::kM0;
static constexpr index_t kN0 = Problem::kN0;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kQKHeaddim = Problem::kQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr index_t kAlignmentQGradAcc =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
// Convert only
template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
CK_TILE_HOST_DEVICE void
operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
{
static_assert(
std::is_same_v<AccDataType,
remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
auto dq_acc_dram_window =
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
dq_acc_dram_block_window_tmp.get_window_lengths(),
dq_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakePostQGradDramTileDistribution<Problem>());
auto dq_acc = load_tile(dq_acc_dram_window);
const auto dq = cast_tile<QGradDataType>(dq_acc);
store_tile(dq_dram_block_window_tmp, dq);
}
// Reduce + Convert
template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
CK_TILE_HOST_DEVICE void
operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
index_t nsplits) const
{
static_assert(
std::is_same_v<AccDataType,
remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
auto dq_acc_dram_window =
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
dq_acc_dram_block_window_tmp.get_window_lengths(),
dq_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakePostQGradAccDramTileDistribution<Problem>());
auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
clear_tile(dq_acc);
constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();
index_t i_total_loops = 0;
auto dq_acc_buf = load_tile(dq_acc_dram_window);
move_tile_window(dq_acc_dram_window, {1, 0, 0});
do
{
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
});
});
});
dq_acc_buf = load_tile(dq_acc_dram_window);
move_tile_window(dq_acc_dram_window, {1, 0, 0});
i_total_loops += 1;
} while(i_total_loops < (nsplits - 1));
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
});
});
});
// declare dq
constexpr auto dq_converted_dstr =
Policy::template MakePostQGradAccDramTileDistribution<Problem>();
auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
});
});
});
constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
dq.get_thread_buffer() = dq_converted.get_thread_buffer();
store_tile(dq_dram_block_window_tmp, dq);
}
};
} // namespace ck_tile
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile { namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy> template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdOGradDotO struct BlockFmhaBwdOGradDotO
{ {
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
...@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO ...@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO
static constexpr index_t kAlignmentO = static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad = static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// These templates are not used here.
using BlockFmhaBwdOGradDotODefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ false,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ false,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
...@@ -6,13 +6,13 @@ ...@@ -6,13 +6,13 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile { namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy> template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
{ {
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
...@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>; using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>; using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
...@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kQLoadOnce = true; static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kQTLoadOnce = false; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kKLoadOnce = true; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kKTLoadOnce = false; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kVLoadOnce = true; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kOGradLoadOnce = true; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kOGradTLoadOnce = false; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this // ... together with tensor distribution. tensor dist should able to overwrite this
...@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>(); kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad = static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad = static constexpr index_t kAlignmentQGrad = 1;
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad = static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>(); kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad = static constexpr index_t kAlignmentVGrad =
...@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
static constexpr index_t kAlignmentBias = static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>(); kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr const char* name = "qs_ks_vr_dos"; static constexpr const char* name = "kr_ktr_vr";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
...@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
} }
template <typename QDramBlockWindowTmp, template <typename QDramBlockWindowTmp,
typename QTDramBlockWindowTmp,
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename KTDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp, typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp, typename OGradDramBlockWindowTmp,
typename OGradTDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp, typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp, typename QGradDramBlockWindowTmp,
...@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/,
const KDramBlockWindowTmp& k_dram_block_window_tmp, const KDramBlockWindowTmp& k_dram_block_window_tmp,
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp, const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp, const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp, const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
...@@ -122,13 +107,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -122,13 +107,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float raw_scale, float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale, float scale,
#endif
float rp_undrop, float rp_undrop,
float scale_rp_undrop, float scale_rp_undrop,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const FmhaDropout& dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -138,9 +121,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -138,9 +121,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> && remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType, std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> && remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> && std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!"); "wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
...@@ -156,77 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -156,77 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); "wrong!");
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
// QT tile in LDS
auto qt_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT<Problem>());
auto qt_lds_window =
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kM0>{}), {0, 0});
// K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
// KT tile in LDS
auto kt_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
auto kt_lds_window =
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
// OGrad tile in LDS
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeQ<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
// OGradT tile in LDS
auto dot_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT<Problem>());
auto dot_lds_window =
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kM0>{}), {0, 0});
// SGrad tile in LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
auto biast_lds_shuffle_window =
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto dbiast_lds_shuffle_window =
make_tile_window(biast_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// Block GEMM // Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>(); constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>(); constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
...@@ -234,34 +144,19 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -234,34 +144,19 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>(); constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>(); constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
auto v = load_tile(v_dram_window); // persistent V register tile
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// init VGrad & KGrad // init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
clear_tile(dv_acc); // K, HBM ->LDS ->Reg
clear_tile(dk_acc); auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
auto k_dram_window = make_tile_window( k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_origin(),
k_dram_block_window_tmp.get_window_lengths(), Policy::template MakeKDramTileDistribution<Problem>());
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier(0);
const auto k_origin = k_dram_window.get_window_origin(); const auto k_origin = k_dram_window.get_window_origin();
// Early termination
const auto [seqlen_q_start, seqlen_q_end] = const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}); mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
...@@ -274,217 +169,408 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -274,217 +169,408 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
{ {
// Note: here dk_acc&dv_acc are all cleard, return it // Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it. // Note: v loaded but no fence, ignore it.
return ck_tile::make_tuple(dk_acc, dv_acc); return make_tuple(dk_acc, dv_acc);
} }
} }
KDataType* k_lds_ptr =
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// V, HBM ->LDS ->Reg
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVDramTileDistribution<Problem>());
VDataType* v_lds_ptr =
static_cast<VDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto v_lds = make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
Policy::template MakeVRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
Policy::template MakeShuffledKRegWriteBlockDescriptor<Problem>());
KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto shuffled_k_lds_write = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
auto kt_lds_read_window =
make_tile_window(kt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
{0, 0},
Policy::template MakeKTRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// Pre-Load KV into Registers
auto k_block_tile = load_tile(k_dram_window); auto k_block_tile = load_tile(k_dram_window);
auto v_block_tile = load_tile(v_dram_window);
store_tile(k_lds_write_window, k_block_tile);
shuffle_tile(shuffled_k_block_tile, k_block_tile);
store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile);
block_sync_lds();
k_reg_tensor = load_tile(k_lds_read_window);
block_sync_lds();
auto kt_reg_tensor = load_tile(kt_lds_read_window);
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS store_tile(v_lds_write_window, v_block_tile);
auto q_dram_block_window = block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window);
block_sync_lds();
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0}); {seqlen_q_start, 0},
Policy::template MakeQDramTileDistribution<Problem>());
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK0>{}),
q_lds_window.get_window_origin(),
Policy::template MakeQRegSliceBlockDescriptor<Problem>());
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
Policy::template MakePTRegSliceBlockDescriptor<Problem>());
// QT: Reg -> Reg-> LDS
auto shuffled_q_block_tile = make_static_distributed_tensor<QDataType>(
Policy::template MakeShuffledQRegWriteBlockDescriptor<Problem>());
auto do_dram_block_window = QDataType* qt_lds_ptr =
static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto shuffled_q_lds_write = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
auto qt_lds_read_window =
make_tile_window(qt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
{0, 0},
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
// dO: HBM ->Reg ->LDS
auto do_dram_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(), do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0}); {seqlen_q_start, 0},
Policy::template MakeOGradDramTileDistribution<Problem>());
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK2>{}),
do_lds_window.get_window_origin(),
Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
// dOT: Reg ->Reg ->LDS
auto shuffled_do_block_tile = make_static_distributed_tensor<OGradDataType>(
Policy::template MakeShuffledOGradRegWriteBlockDescriptor<Problem>());
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>()));
auto dq_dram_block_window = auto shuffled_do_lds_write = make_tensor_view<address_space_enum::lds>(
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto lse_dram_block_window = auto shuffled_do_lds_write_window = make_tile_window(
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
auto d_dram_block_window = auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
auto dot_lds_read_window =
make_tile_window(dot_read_lds,
make_tuple(number<kVHeaddim>{}, number<kM0>{}),
{0, 0},
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
// dS: Reg -> Reg -> LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeD<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto ds_lds_read_window =
make_tile_window(ds_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK4>{}),
ds_lds_window.get_window_origin(),
Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
// Bias: HBM ->Reg ->Reg ->LDS
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_block_window =
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(), bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N {seqlen_q_start, bias_origin.at(number<1>{})},
Policy::template MakeBiasTileDistribution<Problem>());
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); BiasDataType* bias_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
auto dbias_dram_block_window = static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), Policy::template GetSmemSizeOGrad<Problem>() +
dbias_dram_block_window_tmp.get_window_lengths(), Policy::template GetSmemSizeOGradT<Problem>() +
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeD<Problem>()));
auto bias_lds = make_tensor_view<address_space_enum::lds>(
bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
auto bias_lds_write_window =
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto bias_s_lds_read_window =
make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
bias_lds_write_window.get_window_lengths(),
bias_lds_write_window.get_window_origin(),
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window = make_tile_window( auto lse_dram_window = make_tile_window(
lse_dram_block_window.get_bottom_tensor_view(), lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window.get_window_lengths(), lse_dram_block_window_tmp.get_window_lengths(),
lse_dram_block_window.get_window_origin(), {seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
LSEDataType* lse_lds_ptr = static_cast<LSEDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>()));
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// D: HBM ->Reg
auto d_dram_window = make_tile_window( auto d_dram_window = make_tile_window(
d_dram_block_window.get_bottom_tensor_view(), d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window.get_window_lengths(), d_dram_block_window_tmp.get_window_lengths(),
d_dram_block_window.get_window_origin(), {seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window = DDataType* d_lds_ptr = static_cast<DDataType*>(static_cast<void*>(
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
bias_dram_block_window.get_window_lengths(), Policy::template GetSmemSizeOGrad<Problem>() +
bias_dram_block_window.get_window_origin(), Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template MakeBiasTileDistribution<Problem>()); Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>()));
auto d_lds = make_tensor_view<address_space_enum::lds>(
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto biast_lds_window = auto d_lds_read_window = make_tile_window(
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), d_lds,
biast_lds_shuffle_window.get_window_lengths(), make_tuple(number<kM0>{}),
biast_lds_shuffle_window.get_window_origin(), {0},
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>()); Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>( // RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start); randval_dram_block_window_tmp, seqlen_q_start);
index_t i_total_loops = 0; // BiasGrad
constexpr index_t k0_loops = kQKHeaddim / kK0; // Reg ->LDS ->Reg ->HBM
constexpr index_t k1_loops = kM0 / kK1; const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
constexpr index_t k2_loops = kVHeaddim / kK2;
constexpr index_t k3_loops = kM0 / kK3; auto dbias_dram_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto dbias_lds_read_window =
make_tile_window(bias_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
// ----------------------------Loop write out------------------------------//
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
using SPBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4; constexpr index_t k4_loops = kN0 / kK4;
do
{
auto q_dram_window = make_tile_window(
q_dram_block_window.get_bottom_tensor_view(),
q_dram_block_window.get_window_lengths(),
q_dram_block_window.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
// load
auto do_dram_window = make_tile_window(
do_dram_block_window.get_bottom_tensor_view(),
do_dram_block_window.get_window_lengths(),
do_dram_block_window.get_window_origin(),
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0 clear_tile(dv_acc);
auto st_acc = SPTBlockTileType{}; clear_tile(dk_acc);
__builtin_amdgcn_sched_barrier(0);
// Hot loop
while(i_total_loops < num_total_loop)
{
auto q_block_tile = load_tile(q_dram_window); auto q_block_tile = load_tile(q_dram_window);
clear_tile(st_acc); // Initialize S^T move_tile_window(q_dram_window, {kM0, 0});
store_tile(q_lds_window, q_block_tile); // LDS write
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) auto lse_block_tile = load_tile(lse_dram_window);
{ move_tile_window(lse_dram_window, {kM0});
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 1) store_tile(q_lds_window, q_block_tile);
{ shuffle_tile(shuffled_q_block_tile, q_block_tile);
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
block_sync_lds();
gemm_0(st_acc,
get_slice_tile(q_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
block_sync_lds();
});
}
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile store_tile(lse_lds_write_window, lse_block_tile);
{ // tail
block_sync_lds(); block_sync_lds();
gemm_0(st_acc,
get_slice_tile(q_lds_window, auto q_reg_tensor = load_tile(q_lds_read_window);
sequence<0, (k0_loops - 1) * kK0>{}, auto lse = load_tile(lse_lds_read_window);
sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(k_lds_window, block_sync_lds();
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kN0, k0_loops * kK0>{})); // STAGE 1, Q@K Gemm0
block_sync_lds(); auto s_acc = SPBlockTileType{};
}
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
block_sync_lds(); const auto bias_tile = load_tile(bias_dram_window);
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>( auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>()); Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile); shuffle_tile(shuffled_bias_tile, bias_tile);
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); store_tile(bias_lds_write_window, shuffled_bias_tile);
block_sync_lds(); block_sync_lds();
auto biast_tile = load_tile(biast_lds_window); auto bias_s_tile = load_tile(bias_s_lds_read_window);
tile_elementwise_inout( tile_elementwise_inout(
[&](auto& x, const auto& y) { [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = raw_scale * x + type_convert<AccDataType>(y);
#else
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y); x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
}, },
st_acc, s_acc,
biast_tile); bias_s_tile);
move_tile_window(bias_dram_window, {kM0, 0}); move_tile_window(bias_dram_window, {kM0, 0});
__builtin_amdgcn_sched_barrier(0);
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
const auto q_origin = q_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices( const auto tile_idx = get_x_indices_from_distributed_indices(
st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2 s_acc(i_j_idx) *= scale;
st_acc(i_j_idx) *= raw_scale; position_encoding.update(s_acc(i_j_idx), row, col);
#else
st_acc(i_j_idx) *= scale;
#endif
position_encoding.update(st_acc(i_j_idx), row, col);
}); });
}); });
} }
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking) if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{ {
const auto q_origin = q_dram_block_window.get_window_origin(); bool need_perpixel_check = mask.IsEdgeTile(
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check) if(need_perpixel_check)
{ {
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) { set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col); return mask.IsOutOfBound(row, col);
}); });
} }
} }
const auto lse = load_tile(lse_dram_window);
static const auto get_validated_lse = [](LSEDataType raw_lse) { static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking) FmhaMask::IsMasking)
...@@ -499,157 +585,162 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -499,157 +585,162 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
} }
}; };
auto pt = SPTBlockTileType{}; auto p = SPBlockTileType{};
constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); constexpr auto p_spans = decltype(p)::get_distributed_spans();
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2 auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI) BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
} }
else else
{ {
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
} }
#else
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
}); });
}); });
if constexpr(kHasDropout) if constexpr(FmhaDropout::IsDropout)
{ {
dropout.Run<decltype(gemm_0), RandValOutputDataType>( dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
} }
const auto p_gemm = [&]() {
// STAGE 3, P^T@OGrad^T Gemm1 if constexpr(FmhaDropout::IsDropout)
block_sync_lds();
store_tile(do_lds_window, do_block_tile); // store the prefetch
const auto pt_gemm = [&]() {
if constexpr(kHasDropout)
{ {
return tile_elementwise_in( return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); }, [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt); p);
} }
else else
{ {
return cast_tile<GemmDataType>(pt); return cast_tile<GemmDataType>(p);
} }
}(); }();
static_for<0, k1_loops, 1>{}([&](auto i_k1) { // STAGE 3, P^T@OGrad^T Gemm1
block_sync_lds(); auto do_block_tile = load_tile(do_dram_window);
gemm_1(dv_acc, move_tile_window(do_dram_window, {kM0, 0});
get_slice_tile(
pt_gemm, sequence<i_k1 * kK1, 0>{}, sequence<(i_k1 + 1) * kK1, kN0>{}), auto d_block_tile = load_tile(d_dram_window);
get_slice_tile(dot_lds_window, move_tile_window(d_dram_window, {kM0});
sequence<0, i_k1 * kK1>{},
sequence<kVHeaddim, (i_k1 + 1) * kK1>{})); store_tile(do_lds_window, do_block_tile);
block_sync_lds(); shuffle_tile(shuffled_do_block_tile, do_block_tile);
}); store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
store_tile(d_lds_write_window, d_block_tile);
block_sync_lds();
auto dot_reg_tensor = load_tile(dot_lds_read_window);
block_sync_lds();
Policy::template PTFromGemm0CToGemm1A<Problem,
decltype(pt_reg_tensor),
decltype(p_gemm)>(pt_reg_tensor, p_gemm);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
// STAGE 4, OGrad@V Gemm2 // STAGE 4, OGrad@V Gemm2
auto dpt_acc = SPGradTBlockTileType{}; auto do_reg_tensor = load_tile(do_lds_read_window);
clear_tile(dpt_acc); // Initialize PGrad^T auto d = load_tile(d_lds_read_window);
block_sync_lds();
static_for<0, k2_loops, 1>{}([&](auto i_k2) { auto dp_acc = SPGradBlockTileType{};
block_sync_lds();
gemm_2(dpt_acc,
get_slice_tile(do_lds_window,
sequence<0, i_k2 * kK2>{},
sequence<kM0, (i_k2 + 1) * kK2>{}),
get_slice_tile(
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
block_sync_lds();
});
// STAGE 5, P^T(PGrad^T - D) dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
const auto d = load_tile(d_dram_window);
auto dst = SPGradTBlockTileType{}; // STAGE 5, P^T(PGrad^T - D)
constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); auto ds = SPGradBlockTileType{};
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = pt[i_j_idx] >= 0; bool undrop_flag = p[i_j_idx] >= 0;
dst(i_j_idx) = ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
pt[i_j_idx] * ? (dp_acc[i_j_idx] - d[i_idx])
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); : d[i_idx]);
}); });
}); });
if constexpr(kHasBiasGrad) if constexpr(kHasBiasGrad)
{ {
const auto dbiast = [&]() { const auto dbias = [&]() {
if constexpr(kHasDropout) if constexpr(FmhaDropout::IsDropout)
{ {
return tile_elementwise_in( return tile_elementwise_in(
[&rp_undrop](const auto& x) { [&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop); return type_convert<BiasGradDataType>(x * rp_undrop);
}, },
dst); ds);
} }
else else
{ {
return cast_tile<BiasGradDataType>(dst); return cast_tile<BiasGradDataType>(ds);
} }
}(); }();
store_tile(biast_lds_shuffle_window, dbiast); store_tile(bias_lds_write_window, dbias);
block_sync_lds(); block_sync_lds();
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>( auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>()); Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); shuffle_tile(dbias_tile, shuffled_dbias_tile);
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); store_tile(dbias_dram_window, dbias_tile);
move_tile_window(dbias_dram_block_window, {kM0, 0}); move_tile_window(dbias_dram_window, {kM0, 0});
__builtin_amdgcn_sched_barrier(0);
} }
// STAGE 6, SGrad^T@Q^T Gemm3 // STAGE 6, SGrad^T@Q^T Gemm3
auto qt_reg_tensor = load_tile(qt_lds_read_window);
block_sync_lds(); block_sync_lds();
const auto dst_gemm = cast_tile<GemmDataType>(dst);
static_for<0, k3_loops, 1>{}([&](auto i_k3) { const auto ds_gemm = cast_tile<GemmDataType>(ds);
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(
dst_gemm, sequence<i_k3 * kK3, 0>{}, sequence<(i_k3 + 1) * kK3, kN0>{}),
get_slice_tile(qt_lds_window,
sequence<0, i_k3 * kK3>{},
sequence<kQKHeaddim, (i_k3 + 1) * kK3>{}));
block_sync_lds();
});
// STAGE 7, SGrad@K^T Gemm4 Policy::template SGradTFromGemm2CToGemm3A<Problem,
store_tile(ds_lds_window, dst_gemm); decltype(dst_reg_tensor),
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
auto dq_acc = QGradBlockTileType{}; gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
clear_tile(dq_acc); // Initialize QGrad
store_tile(ds_lds_window, ds_gemm);
block_sync_lds(); block_sync_lds();
auto ds_reg_tensor = load_tile(ds_lds_read_window);
auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
move_tile_window(ds_lds_read_window, {0, kK4});
// STAGE7 SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc);
static_for<0, k4_loops, 1>{}([&](auto i_k4) { static_for<0, k4_loops, 1>{}([&](auto i_k4) {
gemm_4(dq_acc, if constexpr(i_k4 < k4_loops - 1)
get_slice_tile(ds_lds_window, {
sequence<0, i_k4 * kK4>{}, ds_reg_tensor_next = load_tile(ds_lds_read_window);
sequence<kM0, (i_k4 + 1) * kK4>{}), move_tile_window(ds_lds_read_window, {0, kK4});
get_slice_tile(kt_lds_window, }
sequence<0, i_k4 * kK4>{}, auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{})); sequence<0, i_k4 * kK4>{},
}); sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
}
});
move_tile_window(ds_lds_read_window, {0, -kN0});
// QGrad Scale // QGrad Scale
if constexpr(kHasDropout) if constexpr(FmhaDropout::IsDropout)
{ {
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc); dq_acc);
...@@ -658,34 +749,33 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -658,34 +749,33 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
{ {
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
} }
const auto dq = cast_tile<QGradDataType>(dq_acc); if constexpr(kIsDeterministic)
update_tile(dq_dram_block_window, dq); {
store_tile(dq_dram_window, dq_acc);
}
else
{
update_tile(dq_dram_window, dq_acc);
}
move_tile_window(dq_dram_window, {kM0, 0});
// move tile windows i_total_loops += 1;
move_tile_window(q_dram_block_window, {kM0, 0}); seqlen_q_step += kM0;
move_tile_window(dq_dram_block_window, {kM0, 0}); }
move_tile_window(do_dram_block_window, {kM0, 0});
move_tile_window(lse_dram_window, {kM0});
move_tile_window(d_dram_window, {kM0});
} while(++i_total_loops < num_total_loop);
// KGrad Scale // Results Scale
if constexpr(kHasDropout) if constexpr(FmhaDropout::IsDropout)
{ {
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc); dk_acc);
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
} }
else else
{ {
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
} }
// VGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}
return ck_tile::make_tuple(dk_acc, dv_acc); return make_tuple(dk_acc, dv_acc);
} }
}; };
......
...@@ -6,13 +6,13 @@ ...@@ -6,13 +6,13 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile { namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy> template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKSKTSVR struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
{ {
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
...@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>; using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>; using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
...@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kQLoadOnce = false; static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kQTLoadOnce = false; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kKLoadOnce = true; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kKTLoadOnce = true; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kVLoadOnce = true; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kOGradLoadOnce = false; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kOGradTLoadOnce = false; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this // ... together with tensor distribution. tensor dist should able to overwrite this
...@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>(); kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad = static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad = static constexpr index_t kAlignmentQGrad = 1;
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad = static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>(); kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad = static constexpr index_t kAlignmentVGrad =
...@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
static constexpr index_t kAlignmentBias = static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>(); kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr const char* name = "ks_kts_vr"; static constexpr const char* name = "kr_ktr_vr_iglp";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
...@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
} }
template <typename QDramBlockWindowTmp, template <typename QDramBlockWindowTmp,
typename QTDramBlockWindowTmp,
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename KTDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp, typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp, typename OGradDramBlockWindowTmp,
typename OGradTDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp, typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp, typename QGradDramBlockWindowTmp,
...@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp, const KDramBlockWindowTmp& k_dram_block_window_tmp,
const KTDramBlockWindowTmp& kt_dram_block_window_tmp,
const VDramBlockWindowTmp& v_dram_block_window_tmp, const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp, const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp, const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
...@@ -122,43 +107,29 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -122,43 +107,29 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float raw_scale, float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale, float scale,
#endif
float rp_undrop, float rp_undrop,
float scale_rp_undrop, float scale_rp_undrop,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const FmhaDropout& dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<QDataType,
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> && std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType,
remove_cvref_t<typename KTDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> && std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType, std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> && remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType, std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> && remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> && std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!"); "wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kVHeaddim ==
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
...@@ -166,83 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -166,83 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); "wrong!");
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeKT<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
// QT tile in LDS
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeKT<Problem>()));
auto qt_lds = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
auto qt_lds_window =
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
// K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
// KT tile in LDS
KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto kt_lds = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor<Problem>());
auto kt_lds_window =
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
// OGrad tile in LDS
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeKT<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
// OGradT tile in LDS
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeKT<Problem>()));
auto dot_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
auto dot_lds_window =
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
// SGrad tile in LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeKT<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeKT<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
auto biast_lds_shuffle_window =
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto dbiast_lds_shuffle_window =
make_tile_window(biast_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// Block GEMM // Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>(); constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>(); constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
...@@ -250,34 +144,19 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -250,34 +144,19 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>(); constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>(); constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
auto v = load_tile(v_dram_window); // persistent V register tile
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// init VGrad & KGrad // init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
clear_tile(dv_acc); // K, HBM ->LDS ->Reg
clear_tile(dk_acc); auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>());
auto k_dram_window = make_tile_window(
k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier(0);
const auto k_origin = k_dram_window.get_window_origin(); const auto k_origin = k_dram_window.get_window_origin();
// Early termination
const auto [seqlen_q_start, seqlen_q_end] = const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}); mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
...@@ -290,272 +169,444 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -290,272 +169,444 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
{ {
// Note: here dk_acc&dv_acc are all cleard, return it // Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it. // Note: v loaded but no fence, ignore it.
return ck_tile::make_tuple(dk_acc, dv_acc); return make_tuple(dk_acc, dv_acc);
} }
} }
KDataType* k_lds_ptr =
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_block_tile = load_tile(k_dram_window); auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// V, HBM ->LDS ->Reg
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVDramTileDistribution<Problem>());
VDataType* v_lds_ptr =
static_cast<VDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto v_lds = make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
Policy::template MakeVRegBlockDescriptor<Problem>());
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS //------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
Policy::template MakeShuffledKRegWriteBlockDescriptor<Problem>());
auto kt_dram_block_window = kt_dram_block_window_tmp; KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto shuffled_k_lds_write = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
auto kt_lds_read_window =
make_tile_window(kt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
{0, 0},
Policy::template MakeKTRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// Pre-Load KV into Registers
auto k_block_tile = load_tile(k_dram_window);
auto v_block_tile = load_tile(v_dram_window);
auto kt_dram_window = make_tile_window( store_tile(k_lds_write_window, k_block_tile);
kt_dram_block_window.get_bottom_tensor_view(), shuffle_tile(shuffled_k_block_tile, k_block_tile);
kt_dram_block_window.get_window_lengths(), store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile);
kt_dram_block_window.get_window_origin(),
Policy::template MakeKTDramTileDistribution<Problem>()); // K^T DRAM tile window for
// load
auto kt_block_tile = load_tile(kt_dram_window); block_sync_lds();
k_reg_tensor = load_tile(k_lds_read_window);
block_sync_lds();
auto kt_shuffle_tmp = make_static_distributed_tensor<KDataType>( auto kt_reg_tensor = load_tile(kt_lds_read_window);
Policy::template MakeShuffledKTRegBlockDescriptor<Problem>());
shuffle_tile(kt_shuffle_tmp, kt_block_tile);
store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS store_tile(v_lds_write_window, v_block_tile);
auto q_dram_block_window = block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window);
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0}); {seqlen_q_start, 0},
Policy::template MakeQDramTileDistribution<Problem>());
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK0>{}),
q_lds_window.get_window_origin(),
Policy::template MakeQRegSliceBlockDescriptor<Problem>());
auto qt_dram_block_window = auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(), Policy::template MakePTRegSliceBlockDescriptor<Problem>());
qt_dram_block_window_tmp.get_window_lengths(), // QT: Reg -> Reg-> LDS
{0, seqlen_q_start}); auto shuffled_q_block_tile = make_static_distributed_tensor<QDataType>(
Policy::template MakeShuffledQRegWriteBlockDescriptor<Problem>());
auto do_dram_block_window = QDataType* qt_lds_ptr =
static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto shuffled_q_lds_write = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
auto qt_lds_read_window =
make_tile_window(qt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
{0, 0},
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
// dO: HBM ->Reg ->LDS
auto do_dram_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(), do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0}); {seqlen_q_start, 0},
Policy::template MakeOGradDramTileDistribution<Problem>());
auto dot_dram_block_window = OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(), static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
dot_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_q_start});
auto dq_dram_block_window = auto do_lds = make_tensor_view<address_space_enum::lds>(
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto lse_dram_block_window = auto do_lds_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start}); auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK2>{}),
do_lds_window.get_window_origin(),
Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
// dOT: Reg ->Reg ->LDS
auto shuffled_do_block_tile = make_static_distributed_tensor<OGradDataType>(
Policy::template MakeShuffledOGradRegWriteBlockDescriptor<Problem>());
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>()));
auto shuffled_do_lds_write = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto d_dram_block_window = auto shuffled_do_lds_write_window = make_tile_window(
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start}); auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
auto dot_lds_read_window =
make_tile_window(dot_read_lds,
make_tuple(number<kVHeaddim>{}, number<kM0>{}),
{0, 0},
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
// dS: Reg -> Reg -> LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeD<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto ds_lds_read_window =
make_tile_window(ds_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK4>{}),
ds_lds_window.get_window_origin(),
Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
// Bias: HBM ->Reg ->Reg ->LDS
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_block_window =
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(), bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N {seqlen_q_start, bias_origin.at(number<1>{})},
Policy::template MakeBiasTileDistribution<Problem>());
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); BiasDataType* bias_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
auto dbias_dram_block_window = static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), Policy::template GetSmemSizeOGrad<Problem>() +
dbias_dram_block_window_tmp.get_window_lengths(), Policy::template GetSmemSizeOGradT<Problem>() +
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeD<Problem>()));
auto bias_lds = make_tensor_view<address_space_enum::lds>(
bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
auto qt_dram_window = auto bias_lds_write_window =
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(), make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
qt_dram_block_window.get_window_lengths(),
qt_dram_block_window.get_window_origin(),
Policy::template MakeQTDramTileDistribution<Problem>());
auto dot_dram_window = auto bias_s_lds_read_window =
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(), make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
dot_dram_block_window.get_window_lengths(), bias_lds_write_window.get_window_lengths(),
dot_dram_block_window.get_window_origin(), bias_lds_write_window.get_window_origin(),
Policy::template MakeOGradTDramTileDistribution<Problem>()); Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window = make_tile_window( auto lse_dram_window = make_tile_window(
lse_dram_block_window.get_bottom_tensor_view(), lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window.get_window_lengths(), lse_dram_block_window_tmp.get_window_lengths(),
lse_dram_block_window.get_window_origin(), {seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
LSEDataType* lse_lds_ptr = static_cast<LSEDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>()));
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// D: HBM ->Reg
auto d_dram_window = make_tile_window( auto d_dram_window = make_tile_window(
d_dram_block_window.get_bottom_tensor_view(), d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window.get_window_lengths(), d_dram_block_window_tmp.get_window_lengths(),
d_dram_block_window.get_window_origin(), {seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window = DDataType* d_lds_ptr = static_cast<DDataType*>(static_cast<void*>(
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
bias_dram_block_window.get_window_lengths(), Policy::template GetSmemSizeOGrad<Problem>() +
bias_dram_block_window.get_window_origin(), Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template MakeBiasTileDistribution<Problem>()); Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>()));
auto biast_lds_window = auto d_lds = make_tensor_view<address_space_enum::lds>(
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
biast_lds_shuffle_window.get_window_lengths(),
biast_lds_shuffle_window.get_window_origin(),
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>( auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start); randval_dram_block_window_tmp, seqlen_q_start);
index_t i_total_loops = 0; // BiasGrad
constexpr index_t k0_loops = kQKHeaddim / kK0; // Reg ->LDS ->Reg ->HBM
constexpr index_t k1_loops = kM0 / kK1; const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
constexpr index_t k2_loops = kVHeaddim / kK2;
constexpr index_t k3_loops = kM0 / kK3; auto dbias_dram_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto dbias_lds_read_window =
make_tile_window(bias_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
// ----------------------------Loop write out------------------------------//
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
using SPBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4; constexpr index_t k4_loops = kN0 / kK4;
do
{
auto q_dram_window = make_tile_window(
q_dram_block_window.get_bottom_tensor_view(),
q_dram_block_window.get_window_lengths(),
q_dram_block_window.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
// load
auto do_dram_window = make_tile_window(
do_dram_block_window.get_bottom_tensor_view(),
do_dram_block_window.get_window_lengths(),
do_dram_block_window.get_window_origin(),
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0 /*
auto st_acc = SPTBlockTileType{}; * Prefetch Q, LSE, dO, D
*/
auto q_block_tile = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
auto lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
auto q_block_tile = load_tile(q_dram_window); auto do_block_tile = load_tile(do_dram_window);
{ move_tile_window(do_dram_window, {kM0, 0});
move_tile_window(q_dram_window, {0, kK0});
clear_tile(st_acc); // Initialize S^T auto d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
store_tile(q_lds_window, q_block_tile); // LDS write 0 /*
q_block_tile = load_tile(q_dram_window); // global read 1 * Store prefetched data into LDS
} */
block_sync_lds();
store_tile(q_lds_window, q_block_tile);
shuffle_tile(shuffled_q_block_tile, q_block_tile);
store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) store_tile(lse_lds_write_window, lse_block_tile);
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 2) store_tile(do_lds_window, do_block_tile);
{ shuffle_tile(shuffled_do_block_tile, do_block_tile);
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
block_sync_lds();
move_tile_window(q_dram_window, {0, kK0});
store_tile(q_lds_window,
q_block_tile); // LDS write i + 1
q_block_tile = load_tile(q_dram_window); // global read i + 2
});
}
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile store_tile(d_lds_write_window, d_block_tile);
{ // tail block_sync_lds();
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kN0, (k0_loops - 1) * kK0>{}));
block_sync_lds();
store_tile(q_lds_window, q_block_tile); /*
block_sync_lds(); * Prefetch LDS data into Reg to Asynchronous Data Movement and MFMA pipeline
*/
gemm_0(st_acc, auto q_reg_tensor = load_tile(q_lds_read_window);
q_lds_window, auto lse = load_tile(lse_lds_read_window);
get_slice_tile(k_lds_window, auto do_reg_tensor = load_tile(do_lds_read_window);
sequence<0, (k0_loops - 1) * kK0>{}, auto d = load_tile(d_lds_read_window);
sequence<kN0, k0_loops * kK0>{}));
} clear_tile(dv_acc);
clear_tile(dk_acc);
__builtin_amdgcn_sched_barrier(0);
// Hot loop
while(i_total_loops < (num_total_loop - 1))
{
// STAGE 1, Q@K Gemm0
auto s_acc = SPBlockTileType{};
q_block_tile = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
do_block_tile = load_tile(do_dram_window);
move_tile_window(do_dram_window, {kM0, 0});
d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
auto dot_reg_tensor = load_tile(dot_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<0>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
block_sync_lds(); const auto bias_tile = load_tile(bias_dram_window);
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>( auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>()); Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile); shuffle_tile(shuffled_bias_tile, bias_tile);
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); store_tile(bias_lds_write_window, shuffled_bias_tile);
block_sync_lds(); block_sync_lds();
auto biast_tile = load_tile(biast_lds_window); auto bias_s_tile = load_tile(bias_s_lds_read_window);
tile_elementwise_inout( tile_elementwise_inout(
[&](auto& x, const auto& y) { [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = raw_scale * x + type_convert<AccDataType>(y);
#else
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y); x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
}, },
st_acc, s_acc,
biast_tile); bias_s_tile);
move_tile_window(bias_dram_window, {kM0, 0}); move_tile_window(bias_dram_window, {kM0, 0});
__builtin_amdgcn_sched_barrier(0);
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
const auto q_origin = q_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices( const auto tile_idx = get_x_indices_from_distributed_indices(
st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2 s_acc(i_j_idx) *= scale;
st_acc(i_j_idx) *= raw_scale; position_encoding.update(s_acc(i_j_idx), row, col);
#else
st_acc(i_j_idx) *= scale;
#endif
position_encoding.update(st_acc(i_j_idx), row, col);
}); });
}); });
} }
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking) if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{ {
const auto q_origin = q_dram_block_window.get_window_origin(); bool need_perpixel_check = mask.IsEdgeTile(
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check) if(need_perpixel_check)
{ {
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) { set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col); return mask.IsOutOfBound(row, col);
}); });
} }
} }
const auto lse = load_tile(lse_dram_window);
static const auto get_validated_lse = [](LSEDataType raw_lse) { static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking) FmhaMask::IsMasking)
...@@ -570,278 +621,416 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -570,278 +621,416 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
} }
}; };
auto pt = SPTBlockTileType{}; auto p = SPBlockTileType{};
constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); constexpr auto p_spans = decltype(p)::get_distributed_spans();
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2 auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI) BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
} }
else else
{ {
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
} }
#else
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
}); });
}); });
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>( if constexpr(FmhaDropout::IsDropout)
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
block_sync_lds();
{
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
}
move_tile_window(dot_dram_window, {0, kK1});
if constexpr(kHasDropout)
{ {
dropout.Run<decltype(gemm_0), RandValOutputDataType>( dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
} }
const auto p_gemm = [&]() {
// STAGE 3, P^T@OGrad^T Gemm1 if constexpr(FmhaDropout::IsDropout)
const auto pt_gemm = [&]() {
if constexpr(kHasDropout)
{ {
return tile_elementwise_in( return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); }, [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt); p);
} }
else else
{ {
return cast_tile<GemmDataType>(pt); return cast_tile<GemmDataType>(p);
} }
}(); }();
if constexpr(k1_loops > 1) // STAGE 3, P^T@OGrad^T Gemm1
{ Policy::template PTFromGemm0CToGemm1A<Problem,
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { decltype(pt_reg_tensor),
const auto dot = load_tile(dot_dram_window); // load next OGrad^T decltype(p_gemm)>(pt_reg_tensor, p_gemm);
block_sync_lds(); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
gemm_1(dv_acc,
get_slice_tile(pt_gemm,
sequence<i_k1 * kK1, 0>{},
sequence<(i_k1 + 1) * kK1, kN0>{}),
dot_lds_window);
block_sync_lds();
shuffle_tile(dot_shuffle_tmp, dot);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
move_tile_window(dot_dram_window, {0, kK1});
});
}
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
// tail
{
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
dot_lds_window);
block_sync_lds();
}
// STAGE 4, OGrad@V Gemm2 auto qt_reg_tensor = load_tile(qt_lds_read_window);
auto dpt_acc = SPGradTBlockTileType{};
{ HotLoopScheduler::template GemmStagedScheduler<1>();
move_tile_window(do_dram_window, {0, kK2}); __builtin_amdgcn_sched_barrier(0);
// STAGE 4, OGrad@V Gemm2
auto dp_acc = SPGradBlockTileType{};
clear_tile(dpt_acc); // Initialize PGrad^T dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
store_tile(do_lds_window, do_block_tile); // LDS write 0 block_sync_lds();
do_block_tile = load_tile(do_dram_window); // global read 1
}
if constexpr(k2_loops > 2) store_tile(q_lds_window, q_block_tile);
{ shuffle_tile(shuffled_q_block_tile, q_block_tile);
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) { store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
block_sync_lds();
move_tile_window(do_dram_window, {0, kK2});
store_tile(do_lds_window,
do_block_tile); // LDS write i + 1
do_block_tile = load_tile(do_dram_window); // global read i + 2
});
}
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile store_tile(lse_lds_write_window, lse_block_tile);
{ // tail
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 2) * kK2>{},
sequence<kN0, (k2_loops - 1) * kK2>{}));
block_sync_lds();
store_tile(do_lds_window, do_block_tile); store_tile(do_lds_window, do_block_tile);
block_sync_lds(); shuffle_tile(shuffled_do_block_tile, do_block_tile);
store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
gemm_2(dpt_acc, store_tile(d_lds_write_window, d_block_tile);
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 1) * kK2>{},
sequence<kN0, k2_loops * kK2>{}));
}
HotLoopScheduler::template GemmStagedScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 5, P^T(PGrad^T - D) // STAGE 5, P^T(PGrad^T - D)
const auto d = load_tile(d_dram_window); auto ds = SPGradBlockTileType{};
constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
auto dst = SPGradTBlockTileType{}; sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = pt[i_j_idx] >= 0; bool undrop_flag = p[i_j_idx] >= 0;
dst(i_j_idx) = ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
pt[i_j_idx] * ? (dp_acc[i_j_idx] - d[i_idx])
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); : d[i_idx]);
}); });
}); });
if constexpr(kHasBiasGrad) if constexpr(kHasBiasGrad)
{ {
const auto dbiast = [&]() { const auto dbias = [&]() {
if constexpr(kHasDropout) if constexpr(FmhaDropout::IsDropout)
{ {
return tile_elementwise_in( return tile_elementwise_in(
[&rp_undrop](const auto& x) { [&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop); return type_convert<BiasGradDataType>(x * rp_undrop);
}, },
dst); ds);
} }
else else
{ {
return cast_tile<BiasGradDataType>(dst); return cast_tile<BiasGradDataType>(ds);
} }
}(); }();
store_tile(biast_lds_shuffle_window, dbiast); store_tile(bias_lds_write_window, dbias);
block_sync_lds(); block_sync_lds();
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>( auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>()); Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); shuffle_tile(dbias_tile, shuffled_dbias_tile);
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); store_tile(dbias_dram_window, dbias_tile);
move_tile_window(dbias_dram_block_window, {kM0, 0}); move_tile_window(dbias_dram_window, {kM0, 0});
__builtin_amdgcn_sched_barrier(0);
} }
// STAGE 6, SGrad^T@Q^T Gemm3 // STAGE 6, SGrad^T@Q^T Gemm3
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>( const auto ds_gemm = cast_tile<GemmDataType>(ds);
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor),
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
store_tile(ds_lds_window, ds_gemm);
block_sync_lds(); block_sync_lds();
auto ds_reg_tensor = load_tile(ds_lds_read_window);
auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
move_tile_window(ds_lds_read_window, {0, kK4});
q_reg_tensor = load_tile(q_lds_read_window);
lse = load_tile(lse_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<3>();
__builtin_amdgcn_sched_barrier(0);
// STAGE7 SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc);
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile(ds_lds_read_window);
move_tile_window(ds_lds_read_window, {0, kK4});
}
auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
}
});
move_tile_window(ds_lds_read_window, {0, -kN0});
do_reg_tensor = load_tile(do_lds_read_window);
d = load_tile(d_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<4>();
// QGrad Scale
if constexpr(FmhaDropout::IsDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
if constexpr(kIsDeterministic)
{ {
shuffle_tile(qt_shuffle_tmp, qt_prefetch); store_tile(dq_dram_window, dq_acc);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
} }
move_tile_window(qt_dram_window, {0, kK3}); else
{
update_tile(dq_dram_window, dq_acc);
}
move_tile_window(dq_dram_window, {kM0, 0});
i_total_loops += 1;
seqlen_q_step += kM0;
}
__builtin_amdgcn_sched_barrier(0);
// Tail
auto s_acc = SPBlockTileType{};
// STAGE 1, Q@K Gemm0
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
const auto dst_gemm = cast_tile<GemmDataType>(dst); // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
const auto bias_tile = load_tile(bias_dram_window);
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(shuffled_bias_tile, bias_tile);
store_tile(bias_lds_write_window, shuffled_bias_tile);
block_sync_lds();
auto bias_s_tile = load_tile(bias_s_lds_read_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
},
s_acc,
bias_s_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if constexpr(k3_loops > 1) s_acc(i_j_idx) *= scale;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
if(need_perpixel_check)
{ {
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) { set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto qt = load_tile(qt_dram_window); // load next Q^T const auto row = seqlen_q_step + tile_idx.at(number<0>{});
block_sync_lds(); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
gemm_3(dk_acc, return mask.IsOutOfBound(row, col);
get_slice_tile(dst_gemm,
sequence<i_k3 * kK3, 0>{},
sequence<(i_k3 + 1) * kK3, kN0>{}),
qt_lds_window);
block_sync_lds();
shuffle_tile(qt_shuffle_tmp, qt);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
move_tile_window(qt_dram_window, {0, kK3});
}); });
} }
// tail }
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{ {
block_sync_lds(); return raw_lse == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
gemm_3(dk_acc, : raw_lse;
get_slice_tile(
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
qt_lds_window);
block_sync_lds();
} }
else
{
return raw_lse;
}
};
// STAGE 7, SGrad@K^T Gemm4 auto p = SPBlockTileType{};
store_tile(ds_lds_window, dst_gemm); constexpr auto p_spans = decltype(p)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
auto dq_acc = QGradBlockTileType{}; sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
clear_tile(dq_acc); // Initialize QGrad constexpr auto i_j_idx = make_tuple(idx0, idx1);
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
}
else
{
p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
}
});
});
block_sync_lds(); if constexpr(FmhaDropout::IsDropout)
{
dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
}
static_for<0, k4_loops, 1>{}([&](auto i_k4) { // STAGE 3, P^T@OGrad^T Gemm1
gemm_4(dq_acc, const auto p_gemm = [&]() {
get_slice_tile(ds_lds_window, if constexpr(FmhaDropout::IsDropout)
sequence<0, i_k4 * kK4>{}, {
sequence<kM0, (i_k4 + 1) * kK4>{}), return tile_elementwise_in(
get_slice_tile(kt_lds_window, [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); }, p);
sequence<0, i_k4 * kK4>{}, }
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{})); else
{
return cast_tile<GemmDataType>(p);
}
}();
Policy::template PTFromGemm0CToGemm1A<Problem, decltype(pt_reg_tensor), decltype(p_gemm)>(
pt_reg_tensor, p_gemm);
auto dot_reg_tensor = load_tile(dot_lds_read_window);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>();
// STAGE 4, OGrad@V Gemm2
auto dp_acc = SPGradBlockTileType{};
auto qt_reg_tensor = load_tile(qt_lds_read_window);
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>();
// STAGE 5, P^T(PGrad^T - D)
auto ds = SPGradBlockTileType{};
constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = p[i_j_idx] >= 0;
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
? (dp_acc[i_j_idx] - d[i_idx])
: d[i_idx]);
}); });
});
// QGrad Scale if constexpr(kHasBiasGrad)
if constexpr(kHasDropout) {
const auto dbias = [&]() {
if constexpr(FmhaDropout::IsDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
ds);
}
else
{
return cast_tile<BiasGradDataType>(ds);
}
}();
store_tile(bias_lds_write_window, dbias);
block_sync_lds();
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbias_tile, shuffled_dbias_tile);
store_tile(dbias_dram_window, dbias_tile);
}
// STAGE 6, SGrad^T@Q^T Gemm3
const auto ds_gemm = cast_tile<GemmDataType>(ds);
Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor),
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
store_tile(ds_lds_window, ds_gemm);
block_sync_lds();
auto ds_reg_tensor = load_tile(ds_lds_read_window);
auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
move_tile_window(ds_lds_read_window, {0, kK4});
HotLoopScheduler::template GemmStagedScheduler<3>();
// STAGE 7, SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc);
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{ {
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, ds_reg_tensor_next = load_tile(ds_lds_read_window);
dq_acc); move_tile_window(ds_lds_read_window, {0, kK4});
} }
else auto kt_reg_tensor_slice = get_slice_tile(
kt_reg_tensor, sequence<0, i_k4 * kK4>{}, sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
if constexpr(i_k4 < k4_loops - 1)
{ {
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
} }
const auto dq = cast_tile<QGradDataType>(dq_acc); });
update_tile(dq_dram_block_window, dq);
// move tile windows HotLoopScheduler::template GemmStagedScheduler<4>();
move_tile_window(q_dram_block_window, {kM0, 0});
move_tile_window(dq_dram_block_window, {kM0, 0});
move_tile_window(do_dram_block_window, {kM0, 0});
move_tile_window(lse_dram_window, {kM0});
move_tile_window(d_dram_window, {kM0});
} while(++i_total_loops < num_total_loop);
// KGrad Scale // Results Scale
if constexpr(kHasDropout) if constexpr(FmhaDropout::IsDropout)
{ {
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc); dk_acc);
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
} }
else else
{ {
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
} }
// VGrad Scale
if constexpr(kHasDropout) if constexpr(kIsDeterministic)
{ {
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); store_tile(dq_dram_window, dq_acc);
}
else
{
update_tile(dq_dram_window, dq_acc);
} }
return ck_tile::make_tuple(dk_acc, dv_acc); return make_tuple(dk_acc, dv_acc);
} }
}; };
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, k & k^t located in lds.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ true,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKSVR
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kQLoadOnce = false;
static constexpr bool kQTLoadOnce = false;
static constexpr bool kKLoadOnce = true;
static constexpr bool kKTLoadOnce = false;
static constexpr bool kVLoadOnce = true;
static constexpr bool kOGradLoadOnce = false;
static constexpr bool kOGradTLoadOnce = false;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr const char* name = "ks_vr";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename QTDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename KTDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename OGradTDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp,
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
FmhaMask mask,
PositionEncoding position_encoding,
float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale,
#endif
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
BlockDropout& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<QDataType,
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kVHeaddim ==
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
// QT tile in LDS
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto qt_lds = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
auto qt_lds_window =
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
// K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
// KT tile in LDS
auto kt_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
auto kt_lds_window =
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
// OGrad tile in LDS
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
// OGradT tile in LDS
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto dot_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
auto dot_lds_window =
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
// SGrad tile in LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
auto biast_lds_shuffle_window =
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto dbiast_lds_shuffle_window =
make_tile_window(biast_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
auto v = load_tile(v_dram_window); // persistent V register tile
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
clear_tile(dv_acc);
clear_tile(dk_acc);
auto k_dram_window = make_tile_window(
k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier(0);
const auto k_origin = k_dram_window.get_window_origin();
const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
{
if(num_total_loop <= 0)
{
// Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it.
return ck_tile::make_tuple(dk_acc, dv_acc);
}
}
auto k_block_tile = load_tile(k_dram_window);
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
auto q_dram_block_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto qt_dram_block_window =
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
qt_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_q_start});
auto do_dram_block_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto dot_dram_block_window =
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
dot_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_q_start});
auto dq_dram_block_window =
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto lse_dram_block_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
auto d_dram_block_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_block_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
auto dbias_dram_block_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto qt_dram_window =
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
qt_dram_block_window.get_window_lengths(),
qt_dram_block_window.get_window_origin(),
Policy::template MakeQTDramTileDistribution<Problem>());
auto dot_dram_window =
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
dot_dram_block_window.get_window_lengths(),
dot_dram_block_window.get_window_origin(),
Policy::template MakeOGradTDramTileDistribution<Problem>());
auto lse_dram_window = make_tile_window(
lse_dram_block_window.get_bottom_tensor_view(),
lse_dram_block_window.get_window_lengths(),
lse_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window = make_tile_window(
d_dram_block_window.get_bottom_tensor_view(),
d_dram_block_window.get_window_lengths(),
d_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
bias_dram_block_window.get_window_lengths(),
bias_dram_block_window.get_window_origin(),
Policy::template MakeBiasTileDistribution<Problem>());
auto biast_lds_window =
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
biast_lds_shuffle_window.get_window_lengths(),
biast_lds_shuffle_window.get_window_origin(),
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kM0 / kK1;
constexpr index_t k2_loops = kVHeaddim / kK2;
constexpr index_t k3_loops = kM0 / kK3;
constexpr index_t k4_loops = kN0 / kK4;
do
{
auto q_dram_window = make_tile_window(
q_dram_block_window.get_bottom_tensor_view(),
q_dram_block_window.get_window_lengths(),
q_dram_block_window.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
// load
auto do_dram_window = make_tile_window(
do_dram_block_window.get_bottom_tensor_view(),
do_dram_block_window.get_window_lengths(),
do_dram_block_window.get_window_origin(),
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
auto st_acc = SPTBlockTileType{};
auto q_block_tile = load_tile(q_dram_window);
{
move_tile_window(q_dram_window, {0, kK0});
clear_tile(st_acc); // Initialize S^T
store_tile(q_lds_window, q_block_tile); // LDS write 0
q_block_tile = load_tile(q_dram_window); // global read 1
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 2)
{
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
block_sync_lds();
move_tile_window(q_dram_window, {0, kK0});
store_tile(q_lds_window,
q_block_tile); // LDS write i + 1
q_block_tile = load_tile(q_dram_window); // global read i + 2
});
}
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
{ // tail
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kN0, (k0_loops - 1) * kK0>{}));
block_sync_lds();
store_tile(q_lds_window, q_block_tile);
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kN0, k0_loops * kK0>{}));
}
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile);
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
block_sync_lds();
auto biast_tile = load_tile(biast_lds_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = raw_scale * x + type_convert<AccDataType>(y);
#else
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
},
st_acc,
biast_tile);
move_tile_window(bias_dram_window, {kM0, 0});
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto q_origin = q_dram_block_window.get_window_origin();
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc(i_j_idx) *= raw_scale;
#else
st_acc(i_j_idx) *= scale;
#endif
position_encoding.update(st_acc(i_j_idx), row, col);
});
});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto q_origin = q_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
const auto lse = load_tile(lse_dram_window);
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_lse == -numeric<LSEDataType>::infinity()
? type_convert<LSEDataType>(0.f)
: raw_lse;
}
else
{
return raw_lse;
}
};
auto pt = SPTBlockTileType{};
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
}
else
{
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
}
#else
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
});
});
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
block_sync_lds();
{
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
}
move_tile_window(dot_dram_window, {0, kK1});
if constexpr(kHasDropout)
{
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
}
// STAGE 3, P^T@OGrad^T Gemm1
const auto pt_gemm = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
}
else
{
return cast_tile<GemmDataType>(pt);
}
}();
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(pt_gemm,
sequence<i_k1 * kK1, 0>{},
sequence<(i_k1 + 1) * kK1, kN0>{}),
dot_lds_window);
block_sync_lds();
shuffle_tile(dot_shuffle_tmp, dot);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
move_tile_window(dot_dram_window, {0, kK1});
});
}
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
// tail
{
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
dot_lds_window);
block_sync_lds();
}
// STAGE 4, OGrad@V Gemm2
auto dpt_acc = SPGradTBlockTileType{};
{
move_tile_window(do_dram_window, {0, kK2});
clear_tile(dpt_acc); // Initialize PGrad^T
store_tile(do_lds_window, do_block_tile); // LDS write 0
do_block_tile = load_tile(do_dram_window); // global read 1
}
if constexpr(k2_loops > 2)
{
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
block_sync_lds();
move_tile_window(do_dram_window, {0, kK2});
store_tile(do_lds_window,
do_block_tile); // LDS write i + 1
do_block_tile = load_tile(do_dram_window); // global read i + 2
});
}
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
{ // tail
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 2) * kK2>{},
sequence<kN0, (k2_loops - 1) * kK2>{}));
block_sync_lds();
store_tile(do_lds_window, do_block_tile);
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 1) * kK2>{},
sequence<kN0, k2_loops * kK2>{}));
}
// STAGE 5, P^T(PGrad^T - D)
const auto d = load_tile(d_dram_window);
auto dst = SPGradTBlockTileType{};
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = pt[i_j_idx] >= 0;
dst(i_j_idx) =
pt[i_j_idx] *
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
});
});
if constexpr(kHasBiasGrad)
{
const auto dbiast = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
dst);
}
else
{
return cast_tile<BiasGradDataType>(dst);
}
}();
store_tile(biast_lds_shuffle_window, dbiast);
block_sync_lds();
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_block_window, {kM0, 0});
}
// STAGE 6, SGrad^T@Q^T Gemm3
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
block_sync_lds();
{
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
}
move_tile_window(qt_dram_window, {0, kK3});
const auto dst_gemm = cast_tile<GemmDataType>(dst);
if constexpr(k3_loops > 1)
{
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
const auto qt = load_tile(qt_dram_window); // load next Q^T
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(dst_gemm,
sequence<i_k3 * kK3, 0>{},
sequence<(i_k3 + 1) * kK3, kN0>{}),
qt_lds_window);
block_sync_lds();
shuffle_tile(qt_shuffle_tmp, qt);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
move_tile_window(qt_dram_window, {0, kK3});
});
}
// tail
{
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
qt_lds_window);
block_sync_lds();
}
// STAGE 7, SGrad@K^T Gemm4
store_tile(ds_lds_window, dst_gemm);
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc); // Initialize QGrad
block_sync_lds();
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
gemm_4(dq_acc,
get_slice_tile(ds_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kM0, (i_k4 + 1) * kK4>{}),
get_slice_tile(kt_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
});
// QGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
const auto dq = cast_tile<QGradDataType>(dq_acc);
update_tile(dq_dram_block_window, dq);
// move tile windows
move_tile_window(q_dram_block_window, {kM0, 0});
move_tile_window(dq_dram_block_window, {kM0, 0});
move_tile_window(do_dram_block_window, {kM0, 0});
move_tile_window(lse_dram_window, {kM0});
move_tile_window(d_dram_window, {kM0});
} while(++i_total_loops < num_total_loop);
// KGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
// VGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}
return ck_tile::make_tuple(dk_acc, dv_acc);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, k located in lds.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, q & k & do located in lds.
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ true,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...@@ -18,60 +20,215 @@ ...@@ -18,60 +20,215 @@
namespace ck_tile { namespace ck_tile {
template <bool QLoadOnce_,
bool QTLoadOnce_,
bool KLoadOnce_,
bool KTLoadOnce_,
bool VLoadOnce_,
bool OGradLoadOnce_,
bool OGradTLoadOnce_>
struct BlockFmhaBwdPipelineDefaultPolicy struct BlockFmhaBwdPipelineDefaultPolicy
{ {
static constexpr bool QLoadOnce = template <typename Problem>
QLoadOnce_; // if q load whole block length (qkhdim) to LDS at once CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
static constexpr bool QTLoadOnce = {
QTLoadOnce_; // if q^t load whole block length (qkhdim) to LDS at once using BlockGemmProblem =
static constexpr bool KLoadOnce = BlockGemmPipelineProblem<typename Problem::QDataType,
KLoadOnce_; // if k load whole block length (qkhdim) to LDS at once typename Problem::KDataType,
static constexpr bool KTLoadOnce = typename Problem::AccDataType,
KTLoadOnce_; // if k^t load whole block length (qkhdim) to LDS at once Problem::kBlockSize,
static constexpr bool VLoadOnce = TileGemmShape<Problem::BlockFmhaShape::kM0,
VLoadOnce_; // if v load whole block length (vhdim) to Vgprs at once Problem::BlockFmhaShape::kN0,
static constexpr bool OGradLoadOnce = Problem::BlockFmhaShape::kK0>>;
OGradLoadOnce_; // if do load whole block length (vhdim) to LDS at once
static constexpr bool OGradTLoadOnce = using WarpGemm = WarpGemmMfmaDispatcher<
OGradTLoadOnce_; // if do^t load whole block length (vhdim) to LDS at once typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
false>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
// these are for global load // these are for global load
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{ {
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
return 16 / sizeof(QDataType); constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
return kVecLoad;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{ {
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
return 16 / sizeof(KDataType); constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
return kVecLoad;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{ {
if constexpr(VLoadOnce) using VDataType = remove_cvref_t<typename Problem::VDataType>;
{ constexpr index_t kBlockSize = Problem::kBlockSize;
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>(); constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
using WG = remove_cvref_t<decltype(config.template at<0>())>; constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
}
else return total_pixels > kMaxVecLoad ? kMaxVecLoad : total_pixels;
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
return 16 / sizeof(VDataType);
}
} }
template <typename Problem> template <typename Problem>
...@@ -84,20 +241,39 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -84,20 +241,39 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad()
{ {
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>; using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
return 16 / sizeof(OGradDataType); constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
return kVecLoad;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQGrad() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
{ {
using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>; using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>(); constexpr index_t kBlockSize = Problem::kBlockSize;
using WG = remove_cvref_t<decltype(config.template at<0>())>; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
using CWarpDstr = typename WG::CWarpDstr; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr auto vec = constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType);
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{}); constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType);
return vec;
constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
return kVecLoad;
} }
template <typename Problem> template <typename Problem>
...@@ -128,60 +304,35 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -128,60 +304,35 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ() CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
if constexpr(QTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK3;
}();
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
// TODO: not correct! return total_pixels / GetAlignmentQ<Problem>();
if constexpr(total_pixels > 4)
return 4;
else
return 2;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentK() CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentK()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
if constexpr(KTLoadOnce)
return Problem::BlockFmhaShape::kN0;
else
return Problem::BlockFmhaShape::kK4;
}();
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
// TODO: not correct! return total_pixels / GetAlignmentK<Problem>();
if constexpr(total_pixels > 4)
return 4;
else
return 2;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad() CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
if constexpr(OGradTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK1;
}();
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
// TODO: not correct! return total_pixels / GetAlignmentOGrad<Problem>();
if constexpr(total_pixels > 4)
return 4;
else
return 2;
} }
template <typename Problem> template <typename Problem>
...@@ -193,1151 +344,1577 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -193,1151 +344,1577 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
// TODO: not correct! return total_pixels / GetAlignmentBias<Problem>();
if constexpr(total_pixels > 32)
return 8;
else
return 4;
} }
// these are for lds
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGradAcc()
{ {
// TODO: this is for 3d layout using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using QDataType = remove_cvref_t<typename Problem::QDataType>; return 16 / sizeof(AccDataType);
return 16 / sizeof(QDataType);
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGrad()
{ {
// TODO: this is for 3d layout return GetAlignmentPostQGradAcc<Problem>();
using KDataType = remove_cvref_t<typename Problem::KDataType>;
return 16 / sizeof(KDataType);
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{ {
// TODO: this is for 3d layout constexpr index_t kBlockSize = Problem::kBlockSize;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
return 16 / sizeof(VDataType); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N1 = get_warp_size() / K0;
constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t N2 = kNPerBlock / (N1 * N0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias() CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution()
{ {
// TODO: this is for 3d layout constexpr index_t kBlockSize = Problem::kBlockSize;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
return 16 / sizeof(BiasDataType); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t K1 = GetAlignmentV<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad() CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{ {
// TODO: this is for 3d layout constexpr index_t kBlockSize = Problem::kBlockSize;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
return 16 / sizeof(OGradDataType); constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M1 = get_warp_size() / K0;
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M2 = kMPerBlock / (M1 * M0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad() CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution()
{ {
// TODO: this is for 3d layout constexpr index_t kBlockSize = Problem::kBlockSize;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
return 16 / sizeof(GemmDataType);
}
template <typename Problem, typename BlockGemm> constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
CK_TILE_HOST_DEVICE static constexpr auto MakeVInRegDramTileDistribution() constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>(); constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M1 = get_warp_size() / K0;
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M2 = kMPerBlock / (M1 * M0);
using WG = remove_cvref_t<decltype(config.template at<0>())>; return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution()
{
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t MWarp = config.template at<1>(); constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>(); constexpr index_t NWarp = config.template at<2>();
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( // Duplicate dimension
v_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); constexpr index_t N0 = NWarp;
constexpr index_t N1 =
(get_warp_size() / kMPerBlock) > 1 ? (get_warp_size() / kMPerBlock) : 1;
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); constexpr index_t M0 = MWarp;
constexpr index_t M1 = (get_warp_size() / kMPerBlock) > 1 ? kMPerBlock : get_warp_size();
constexpr index_t M2 =
(get_warp_size() / kMPerBlock) > 1 ? 1 : (kMPerBlock / get_warp_size());
return v_block_dstr; return make_static_tile_distribution(
tile_distribution_encoding<sequence<N0, N1>,
tuple<sequence<M0, M1, M2>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1>,
sequence<2>>{});
} }
// 3d + padding template <typename Problem>
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack> CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
{ {
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr index_t kBlockSize = Problem::kBlockSize;
make_tuple(number<KPerBlock / KPack>{}, number<MNPerBlock>{}, number<KPack>{}),
make_tuple(number<(MNPerBlock + 1) * KPack>{}, number<KPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto x_lds_block_desc = transform_tensor_descriptor(
x_lds_block_desc_0,
make_tuple(make_pass_through_transform(MNPerBlock),
make_merge_transform(make_tuple(KPerBlock / KPack, KPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return x_lds_block_desc;
}
// 3d + padding constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack> constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptorAsXT()
{
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MNPerBlock>{}, number<KPack>{}),
make_tuple(number<(MNPerBlock + 1) * KPack>{}, number<KPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto xt_lds_block_desc = transform_tensor_descriptor( constexpr index_t N1 = GetAlignmentBias<Problem>();
x_lds_block_desc_0, constexpr index_t N0 = kNPerBlock / N1;
make_tuple(make_pass_through_transform(MNPerBlock), constexpr index_t M1 = get_warp_size() / N0;
make_merge_transform(make_tuple(KPerBlock / KPack, KPack))), constexpr index_t M0 = kBlockSize / get_warp_size();
make_tuple(sequence<1>{}, sequence<0, 2>{}), constexpr index_t M2 = kMPerBlock / (M1 * M0);
make_tuple(sequence<1>{}, sequence<0>{}));
return xt_lds_block_desc; return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
} }
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack, index_t PixelsPerRow> template <typename DataType, index_t MPerBlock, index_t KPerBlock>
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution()
{ {
static_assert(PixelsPerRow % KPack == 0); constexpr index_t K1 = 16 / sizeof(DataType);
constexpr index_t NPerRow = PixelsPerRow / KPack; constexpr index_t K0 = KPerBlock / K1;
static_assert(MNPerBlock % NPerRow == 0); constexpr index_t M2 = 1;
static_assert(KPerBlock % KPack == 0); constexpr index_t M1 = get_warp_size();
constexpr index_t M0 = MPerBlock / M1;
constexpr auto xt_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{},
number<MNPerBlock / NPerRow>{},
number<NPerRow>{},
number<KPack>{}),
make_tuple(number<(MNPerBlock / NPerRow) * (PixelsPerRow + KPack)>{},
number<PixelsPerRow + KPack>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto xt_lds_block_desc = transform_tensor_descriptor(
xt_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<MNPerBlock / NPerRow>{}, number<NPerRow>{})),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return xt_lds_block_desc; return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1>>,
tuple<sequence<0>, sequence<1>>,
sequence<1, 2, 2>,
sequence<2, 0, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakePreODramTileDistribution()
{ {
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; using ODataType = remove_cvref_t<typename Problem::ODataType>;
constexpr index_t kKPerBlock = [&]() {
if constexpr(QLoadOnce)
return Problem::BlockFmhaShape::kQKHeaddim;
else
return Problem::BlockFmhaShape::kK0;
}();
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
}
template <typename Problem> constexpr index_t kBlockSize = Problem::kBlockSize;
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptorAsQT() constexpr index_t kKPerBlock = Problem::kVHeaddim;
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = [&]() {
if constexpr(QLoadOnce)
return Problem::BlockFmhaShape::kQKHeaddim;
else
return Problem::BlockFmhaShape::kK0;
}();
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
return MakeXLdsBlockDescriptorAsXT<kMPerBlock, kKPerBlock, kKPack>(); return MakePreXDramTileDistribution<ODataType, kBlockSize, kKPerBlock>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakePreOGradDramTileDistribution()
{ {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
constexpr index_t kKPerBlock = [&]() {
if constexpr(KLoadOnce)
return Problem::BlockFmhaShape::kQKHeaddim;
else
return Problem::BlockFmhaShape::kK0;
}();
constexpr index_t kKPack = GetSmemKPackK<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>(); constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::kVHeaddim;
return MakePreXDramTileDistribution<OGradDataType, kBlockSize, kKPerBlock>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptorAsKT() CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution()
{ {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
constexpr index_t kKPerBlock = [&]() {
if constexpr(KLoadOnce)
return Problem::BlockFmhaShape::kQKHeaddim;
else
return Problem::BlockFmhaShape::kK0;
}();
constexpr index_t kKPack = GetSmemKPackK<Problem>();
return MakeXLdsBlockDescriptorAsXT<kNPerBlock, kKPerBlock, kKPack>(); constexpr index_t kBlockSize = Problem::kBlockSize;
} constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::kQKHeaddim;
template <typename Problem> constexpr index_t K1 = 16 / sizeof(AccDataType);
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() constexpr index_t K0 = kKPerBlock / K1;
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>(); constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M1 * M2);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<1>, sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<2>, sequence<2, 3>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2, 3>,
sequence<0, 0, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDramTileDistribution()
{ {
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
constexpr index_t kKPerBlock = [&]() {
if constexpr(OGradLoadOnce)
return Problem::BlockFmhaShape::kVHeaddim;
else
return Problem::BlockFmhaShape::kK2;
}();
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>(); constexpr index_t kBlockSize = Problem::kBlockSize;
} constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::kQKHeaddim;
template <typename Problem> constexpr index_t K1 = 16 / sizeof(AccDataType);
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptorAsOGradT() constexpr index_t K0 = kKPerBlock / K1;
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t kKPerBlock = [&]() { constexpr index_t M1 = kBlockSize / get_warp_size();
if constexpr(OGradLoadOnce) constexpr index_t M0 = kMPerBlock / (M1 * M2);
return Problem::BlockFmhaShape::kVHeaddim;
else
return Problem::BlockFmhaShape::kK2;
}();
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
return MakeXLdsBlockDescriptorAsXT<kMPerBlock, kKPerBlock, kKPack>(); return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
} }
// these are for lds
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{ {
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; return GetAlignmentQ<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPack = GetSmemKPackSGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQT()
{ {
using QDataType = remove_cvref_t<typename Problem::QDataType>; return GetTransposedAlignmentQ<Problem>();
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QDataType);
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = [&]() {
if constexpr(QTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK3;
}();
return MakeXTLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack, PixelsPerRow>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
{ {
using KDataType = remove_cvref_t<typename Problem::KDataType>; return GetAlignmentK<Problem>();
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = [&]() {
if constexpr(KTLoadOnce)
return Problem::BlockFmhaShape::kN0;
else
return Problem::BlockFmhaShape::kK4;
}();
return MakeXTLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack, PixelsPerRow>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackKT()
{ {
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>; return GetTransposedAlignmentK<Problem>();
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(OGradDataType);
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = [&]() {
if constexpr(OGradTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK1;
}();
return MakeXTLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack, PixelsPerRow>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{ {
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>; return GetAlignmentV<Problem>();
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(BiasDataType);
constexpr index_t kKPack = GetSmemKPackBias<Problem>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kMPerBlock % kKPack == 0);
constexpr auto biast_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock / kKPack>{},
number<kNPerBlock / NPerRow>{},
number<NPerRow>{},
number<kKPack>{}),
make_tuple(number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
number<PixelsPerRow + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto biast_lds_block_desc = transform_tensor_descriptor(
biast_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
make_merge_transform(make_tuple(number<kMPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return biast_lds_block_desc;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias()
{ {
constexpr index_t smem_size_q = sizeof(typename Problem::QDataType) * return GetAlignmentBias<Problem>();
MakeQLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_q;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQT() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBiasT()
{ {
constexpr index_t smem_size_qt = [&]() { return GetTransposedAlignmentBias<Problem>();
if constexpr(QLoadOnce && !QTLoadOnce)
return 0;
else
return sizeof(typename Problem::QDataType) *
MakeQTLdsBlockDescriptor<Problem>().get_element_space_size();
}();
return smem_size_qt;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad()
{ {
constexpr index_t smem_size_k = sizeof(typename Problem::KDataType) * return GetAlignmentOGrad<Problem>();
MakeKLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_k;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKT() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGradT()
{ {
constexpr index_t smem_size_kt = [&]() { return GetTransposedAlignmentOGrad<Problem>();
if constexpr(KLoadOnce && !KTLoadOnce)
return 0;
else
return sizeof(typename Problem::KDataType) *
MakeKTLdsBlockDescriptor<Problem>().get_element_space_size();
}();
return smem_size_kt;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad()
{ {
constexpr index_t smem_size_v = [&]() { // TODO: this is for 3d layout
if constexpr(VLoadOnce) using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
return 0; return 16 / sizeof(GemmDataType);
else
return sizeof(typename Problem::VDataType) *
MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
}();
return smem_size_v;
} }
template <typename Problem> template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGrad() CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
{ {
constexpr index_t smem_size_do = constexpr auto DataTypeSize = 2; // sizeof(F16/BF16)
sizeof(typename Problem::OGradDataType) * constexpr auto MNLdsLayer =
MakeOGradLdsBlockDescriptor<Problem>().get_element_space_size(); (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
return smem_size_do;
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * MNLdsLayer>{},
number<MNPerBlock / MNLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * MNLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor(
x_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
number<KPerBlock / KPack * MNLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
x_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<KPerBlock / KPack>{}, number<MNLdsLayer>{})),
make_pass_through_transform(number<MNPerBlock / MNLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto x_lds_block_desc = transform_tensor_descriptor(
x_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<MNPerBlock / MNLdsLayer>{}, number<MNLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return x_lds_block_desc;
} }
template <typename Problem> template <typename Problem,
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGradT() index_t MNPerBlock,
index_t KPerBlock,
index_t KPack,
index_t KPackT>
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor()
{ {
constexpr index_t smem_size_dot = [&]() { // kfold and mpair dimension is not always required.
if constexpr(OGradLoadOnce && !OGradTLoadOnce) // more dimension in merge_transform increase the difficulty of generating immarg offset
return 0; // for compiler.
else constexpr auto MNPerXDL = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
return sizeof(typename Problem::OGradDataType) * constexpr auto kBlockSize = Problem::kBlockSize;
MakeOGradTLdsBlockDescriptor<Problem>().get_element_space_size();
}(); constexpr auto MN0 = MNPerBlock / KPack;
return smem_size_dot; constexpr auto MN1 = KPack;
constexpr auto KThreadWrite = kBlockSize / MN0;
constexpr auto K0Number = KPerBlock / KPackT;
constexpr auto K0PerThreadWrite = K0Number / KThreadWrite;
constexpr auto KThreadRead = get_warp_size() / MNPerXDL; // assume 32x32x8 mfma
constexpr auto K0PerThreadRead = K0Number / KThreadRead;
constexpr auto kfold = (KPackT * MN0 * 2 > 128) ? 1 : 128 / (KPackT * MN0 * 2);
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mnpair<=n0
constexpr auto mnpair =
(KPackT * MNPerXDL * 2 > 128)
? 1
: ((128 / (KPackT * MNPerXDL * 2)) > MN0 ? MN0 : 128 / (KPackT * MNPerXDL * 2));
constexpr auto xt_lds_block_desc_raw = make_naive_tensor_descriptor(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * MN1>{},
number<kfold * MN0 / mnpair>{},
number<mnpair>{},
KPackT),
make_tuple(number<KPackT * kfold * MN0 * KThreadReadPerm * MN1 * K0PerThreadWrite>{},
number<KPackT * kfold * MN0 * KThreadReadPerm * MN1>{},
number<KPackT * kfold * MN0>{},
number<KPackT * mnpair>{},
number<KPackT>{},
number<1>{}),
number<KPackT>{},
number<1>{});
constexpr auto xt_lds_block_desc_permuted = transform_tensor_descriptor(
xt_lds_block_desc_raw,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * MN1>{}, number<kfold * MN0 / mnpair>{})),
make_pass_through_transform(number<mnpair>{}),
make_pass_through_transform(KPackT)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
constexpr auto xt_lds_block_desc_unmerged = transform_tensor_descriptor(
xt_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<MN1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<MN0 / mnpair>{})),
make_pass_through_transform(number<mnpair>{}),
make_pass_through_transform(KPackT)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
constexpr auto xt_lds_block_desc = transform_tensor_descriptor(
xt_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
number<KPackT>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<MN0 / mnpair>{}, number<mnpair>{}, number<MN1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return xt_lds_block_desc;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSGrad() CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
{ {
constexpr index_t smem_size_ds = constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
sizeof(typename Problem::GemmDataType) * constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
MakeSGradLdsBlockDescriptor<Problem>().get_element_space_size(); constexpr index_t kKPack = GetSmemKPackK<Problem>();
return smem_size_ds;
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeBias() CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
{ {
constexpr index_t smem_size_bias = [&]() { using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
return sizeof(typename Problem::BiasDataType) * using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
MakeBiasTLdsBlockDescriptor<Problem>().get_element_space_size();
else
return 0;
}();
return smem_size_bias;
}
template <typename Problem> constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
{
constexpr index_t smem_size_q = GetSmemSizeQ<Problem>();
constexpr index_t smem_size_qt = GetSmemSizeQT<Problem>();
constexpr index_t smem_size_k = GetSmemSizeK<Problem>();
constexpr index_t smem_size_kt = GetSmemSizeKT<Problem>();
constexpr index_t smem_size_v = GetSmemSizeV<Problem>();
constexpr index_t smem_size_do = GetSmemSizeOGrad<Problem>();
constexpr index_t smem_size_dot = GetSmemSizeOGradT<Problem>();
constexpr index_t smem_size_ds = GetSmemSizeSGrad<Problem>();
constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
constexpr index_t smem_size_transpose = max(smem_size_ds, smem_size_bias);
index_t smem_size = 0;
if constexpr(QLoadOnce && OGradLoadOnce)
smem_size += smem_size_q + smem_size_qt + smem_size_do + smem_size_dot +
smem_size_transpose; // 1~4 & 10
else if(QLoadOnce && !OGradLoadOnce && !OGradTLoadOnce)
smem_size += smem_size_q + smem_size_qt +
max(smem_size_do,
smem_size_dot,
smem_size_transpose); // 5/7/11 TODO: Multiple buffers strategy
else if(!QLoadOnce && !QTLoadOnce && OGradLoadOnce)
smem_size += smem_size_do + smem_size_dot +
max(smem_size_q,
smem_size_qt,
smem_size_transpose); // 6/8/12 TODO: Multiple buffers strategy
else if(!QLoadOnce && !QTLoadOnce && !OGradLoadOnce && !OGradTLoadOnce)
smem_size += max(smem_size_q,
smem_size_qt,
smem_size_do,
smem_size_dot,
smem_size_transpose); // 9/13 TODO: Multiple buffers strategy
// 14/15 needs to be adjusted
if constexpr(KLoadOnce)
smem_size += (smem_size_k + smem_size_kt); // 1~13
else
smem_size =
max(smem_size_k, smem_size_kt, smem_size); // 14/15 TODO: Multiple buffers strategy
return max(smem_size, smem_size_v); // 15 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
} constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
template <typename Problem, typename BlockGemm> constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution() constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
{
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr auto k_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane; constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
constexpr index_t N0 = NWarp; k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * 2; constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane;
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / 2;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
return make_static_tile_distribution( return k_block_dstr;
tile_distribution_encoding<sequence<N0, N1>,
tuple<sequence<M0, M1, M2, M3, M4>>,
tuple<sequence<1, 0>, sequence<1, 0>>,
tuple<sequence<1, 0>, sequence<3, 1>>,
sequence<1, 1, 1>,
sequence<0, 2, 4>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
{ {
using VDataType = remove_cvref_t<typename Problem::VDataType>; using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = 16 / sizeof(VDataType); constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each blocks
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution( constexpr auto k_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>, tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1>, sequence<1, 2>>, tuple<sequence<0, 1>>,
tuple<sequence<1>, sequence<2, 0>>, tuple<sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 1>>{}); sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
return k_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kVPack = GetSmemKPackV<Problem>();
constexpr index_t kKPerBlock = [&]() {
if constexpr(QLoadOnce)
return Problem::BlockFmhaShape::kQKHeaddim;
else
return Problem::BlockFmhaShape::kK0;
}();
constexpr index_t K1 = GetAlignmentQ<Problem>(); return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
constexpr index_t K0 = kKPerBlock / K1; }
constexpr index_t M2 = get_warp_size() / K0;
// coalesce reading for each blocks
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution( template <typename Problem>
tile_distribution_encoding<sequence<>, CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, {
tuple<sequence<1>, sequence<1, 2>>, using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
tuple<sequence<1>, sequence<2, 0>>, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 1>>{}); sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
return v_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
if constexpr(KLoadOnce)
return Problem::BlockFmhaShape::kQKHeaddim;
else
return Problem::BlockFmhaShape::kK0;
}();
constexpr index_t K1 = GetAlignmentK<Problem>(); constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each blocks
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution( constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>, tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1>, sequence<1, 2>>, tuple<sequence<0, 1>>,
tuple<sequence<1>, sequence<2, 0>>, tuple<sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 1>>{}); sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
return v_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = [&]() {
if constexpr(OGradLoadOnce)
return Problem::BlockFmhaShape::kVHeaddim;
else
return Problem::BlockFmhaShape::kK2;
}();
constexpr index_t K1 = GetAlignmentOGrad<Problem>(); constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0; constexpr index_t N2 = GetTransposedAlignmentK<Problem>();
// coalesce reading for each blocks constexpr index_t N1 = get_warp_size() / K0;
constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>, tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>, tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>, sequence<2, 1>,
sequence<0, 1>>{}); sequence<1, 2>>{});
} }
template <typename DataType, index_t MPerBlock, index_t KPerBlock> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKLdsWriteBlockDescriptor()
{ {
constexpr index_t K1 = 16 / sizeof(DataType); // Hold all data
constexpr index_t K0 = KPerBlock / K1; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t M2 = 1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t M1 = get_warp_size();
constexpr index_t M0 = MPerBlock / M1;
return make_static_tile_distribution( constexpr index_t kKPack = GetSmemKPackK<Problem>();
tile_distribution_encoding<sequence<>, constexpr index_t kKPackT = GetSmemKPackKT<Problem>();
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1>>, return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
tuple<sequence<0>, sequence<1>>,
sequence<1, 2, 2>,
sequence<2, 0, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePreODramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsReadBlockDescriptor()
{ {
using ODataType = remove_cvref_t<typename Problem::ODataType>; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kBlockSize = Problem::kBlockSize; auto shuffled_k_lds_block_desc = MakeShuffledKLdsWriteBlockDescriptor<Problem>();
constexpr index_t kKPerBlock = Problem::kVHeaddim;
return MakePreXDramTileDistribution<ODataType, kBlockSize, kKPerBlock>(); return transform_tensor_descriptor(
shuffled_k_lds_block_desc,
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePreOGradDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegBlockDescriptor()
{ {
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>; using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
constexpr index_t kKPerBlock = Problem::kVHeaddim; constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
return MakePreXDramTileDistribution<OGradDataType, kBlockSize, kKPerBlock>(); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto kt_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto kt_block_dstr = make_static_tile_distribution(kt_block_dstr_encode);
return kt_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeQTDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = [&]() {
if constexpr(QTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK3;
}();
constexpr index_t N1 = GetTransposedAlignmentQ<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackQ<Problem>(); constexpr index_t kKPack = GetSmemKPackQ<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution( return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQTRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t kKPerBlock = [&]() { using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
if constexpr(QTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK3;
}();
constexpr index_t N1 = GetTransposedAlignmentQ<Problem>(); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t N0 = kNPerBlock / N1; constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution( constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
tile_distribution_encoding<sequence<>, constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>, constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
tuple<sequence<0>, sequence<1, 0, 2>>, constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto q_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>, sequence<1, 2>,
sequence<1, 3>>{}); sequence<0, 0>>{};
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
return q_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKTDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQRegWriteBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = [&]() {
if constexpr(KTLoadOnce)
return Problem::BlockFmhaShape::kN0;
else
return Problem::BlockFmhaShape::kK4;
}();
constexpr index_t N1 = GetTransposedAlignmentK<Problem>(); constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t K1 = GetAlignmentQ<Problem>();
static_assert(total_pixels % N1 == 0); // TODO: this is not always true? constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t K3 = total_pixels / N1; constexpr index_t N2 = GetTransposedAlignmentQ<Problem>();
constexpr index_t kKPack = GetSmemKPackK<Problem>(); constexpr index_t N1 = get_warp_size() / K0;
static_assert(kKPack % K3 == 0); constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>, tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<2>, sequence<2, 1, 2>>, tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>, tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>, sequence<2, 1>,
sequence<3, 1>>{}); sequence<1, 2>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKTRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQLdsWriteBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; // Hold full block data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
if constexpr(KTLoadOnce)
return Problem::BlockFmhaShape::kN0;
else
return Problem::BlockFmhaShape::kK4;
}();
constexpr index_t N1 = GetTransposedAlignmentK<Problem>(); constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t N0 = kNPerBlock / N1; constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution( return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeOGradTDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsReadBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; // Hold full block data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
if constexpr(OGradTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK1;
}();
constexpr index_t N1 = GetTransposedAlignmentOGrad<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; auto shuffled_q_lds_block_desc = MakeShuffledQLdsWriteBlockDescriptor<Problem>();
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution( return transform_tensor_descriptor(
tile_distribution_encoding<sequence<>, shuffled_q_lds_block_desc,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>, make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
tuple<sequence<2>, sequence<2, 1, 2>>, make_pass_through_transform(number<kKPerBlock>{})),
tuple<sequence<0>, sequence<1, 0, 2>>, make_tuple(sequence<1>{}, sequence<0>{}),
sequence<2, 1>, make_tuple(sequence<0>{}, sequence<1>{}));
sequence<3, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradTRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeQTRegSliceBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; using BlockGemm = remove_cvref_t<decltype(GetSGradTQTBlockGemm<Problem>())>;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t kKPerBlock = [&]() { using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
if constexpr(OGradTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK1;
}();
constexpr index_t N1 = GetTransposedAlignmentOGrad<Problem>(); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
constexpr index_t N0 = kNPerBlock / N1; constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution( constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
tile_distribution_encoding<sequence<>, constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>, constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
tuple<sequence<0>, sequence<1, 0, 2>>, constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto qt_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<1, 3>>{}); sequence<0, 0>>{};
constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto qt_block_dstr = make_static_tile_distribution(qt_block_dstr_encode);
return qt_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBiasTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeSGradTRegSliceBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; using BlockGemm = remove_cvref_t<decltype(GetSGradTQTBlockGemm<Problem>())>;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t N1 = GetTransposedAlignmentBias<Problem>(); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
constexpr index_t N0 = kNPerBlock / N1; // P constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true? constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
constexpr index_t M3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackBias<Problem>();
static_assert(kKPack % M3 == 0);
constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave
constexpr index_t M1 = get_warp_size() / (M2 * N0);
constexpr index_t M0 = kBlockSize / get_warp_size();
static_assert(kMPerBlock == M0 * M1 * M2 * M3);
return make_static_tile_distribution( constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
tile_distribution_encoding<sequence<>, constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
tuple<sequence<M0, M1, M2, M3>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2, 1>>, constexpr auto dst_block_outer_dstr_encoding =
tuple<sequence<0>, sequence<1, 0, 2>>, tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>, sequence<1, 2>,
sequence<3, 1>>{}); sequence<0, 0>>{};
constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode);
return dst_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsWriteBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; using LSEDType = remove_cvref_t<typename Problem::DDataType>;
constexpr index_t kMPack = 16 / sizeof(LSEDType);
constexpr index_t N1 = GetTransposedAlignmentBias<Problem>(); constexpr auto lsed_lds_block_desc =
constexpr index_t N0 = kNPerBlock / N1; make_naive_tensor_descriptor(make_tuple(number<kMPerBlock>{}),
constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; make_tuple(number<1>{}),
static_assert(total_pixels % N1 == 0); // TODO: this is not always true? number<kMPack>{},
constexpr index_t M3 = total_pixels / N1; number<1>{});
constexpr index_t kKPack = GetSmemKPackBias<Problem>();
static_assert(kKPack % M3 == 0);
constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave
constexpr index_t M1 = get_warp_size() / (M2 * N0);
constexpr index_t M0 = kBlockSize / get_warp_size();
return make_static_tile_distribution( return lsed_lds_block_desc;
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2, M3>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2, 1>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<1, 3>>{});
} }
template <typename BlockGemm> template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor()
{ {
using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile()); constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
return c_block_tensor_type::get_tile_distribution(); using WG = remove_cvref_t<decltype(config.template at<0>())>;
} constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane;
constexpr index_t N0 = NWarp;
// M4 *2 and M2 /2 when swizzle mode enabled
constexpr index_t SwizzleConfig = WG::kM == 16 ? 1 : 2;
// constexpr index_t SwizzleConfig = 1;
constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * SwizzleConfig;
constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane;
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / SwizzleConfig;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N0, N1>,
tuple<sequence<M0, M1, M2, M3, M4>>,
tuple<sequence<1, 0>, sequence<1, 0>>,
tuple<sequence<1, 0>, sequence<3, 1>>,
sequence<1, 1, 1>,
sequence<0, 2, 4>>{});
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor()
{ {
using BlockGemmProblem = // Hold full block data
BlockGemmPipelineProblem<typename Problem::QDataType, constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
typename Problem::KDataType, constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
constexpr auto warp_gemm = []() { constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
{
return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
{
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
}
}();
using BlockGemmPolicy = return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::QDataType, }
typename Problem::KDataType,
typename Problem::AccDataType, template <typename Problem>
typename Problem::BlockFmhaShape::Gemm0BlockWarps, CK_TILE_HOST_DEVICE static constexpr auto MakeOGradRegSliceBlockDescriptor()
decltype(warp_gemm)>; {
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto do_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode);
return do_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradRegWriteBlockDescriptor()
{ {
using BlockGemmProblem = constexpr index_t kBlockSize = Problem::kBlockSize;
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>>;
using WarpGemm = constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::OGradDataType, constexpr index_t K1 = GetAlignmentOGrad<Problem>();
typename Problem::AccDataType, constexpr index_t K0 = kKPerBlock / K1;
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), constexpr index_t N2 = GetTransposedAlignmentOGrad<Problem>();
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), constexpr index_t N1 = get_warp_size() / K0;
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), constexpr index_t N0 = kBlockSize / get_warp_size();
true>;
using BlockGemmPolicy = return make_static_tile_distribution(
BlockGemmARegBSmemCRegV1CustomPolicy<typename Problem::GemmDataType, tile_distribution_encoding<sequence<>,
typename Problem::OGradDataType, tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
typename Problem::AccDataType, tuple<sequence<1>, sequence<1, 2>>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps, tuple<sequence<0>, sequence<1, 0>>,
WarpGemm>; sequence<2, 1>,
return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; sequence<1, 2>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradLdsWriteBlockDescriptor()
{ {
using BlockGemmProblem = // Hold all data
BlockGemmPipelineProblem<typename Problem::OGradDataType, constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
typename Problem::VDataType, constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>>;
constexpr auto warp_gemm = []() { constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> && constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
std::is_same_v<typename Problem::VDataType, half_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
{
return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
}
else if constexpr(std::is_same_v<typename Problem::OGradDataType, bf16_t> &&
std::is_same_v<typename Problem::VDataType, bf16_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
{
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
}
}();
using BlockGemmPolicy = return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::OGradDataType, }
typename Problem::VDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
decltype(warp_gemm)>;
return BlockGemmASmemBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
// {
// using BlockGemmProblem =
// BlockGemmPipelineProblem<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// Problem::kBlockSize,
// TileGemmShape<Problem::BlockFmhaShape::kM0,
// Problem::BlockFmhaShape::kN0,
// Problem::BlockFmhaShape::kK2>>;
// constexpr auto warp_gemm = []() {
// if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> &&
// std::is_same_v<typename Problem::VDataType, half_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
// }
// else if constexpr(std::is_same_v<typename Problem::OGradDataType, bf16_t> &&
// std::is_same_v<typename Problem::VDataType, bf16_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
// }
// }();
// using BlockGemmPolicy =
// BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// typename
// Problem::BlockFmhaShape::Gemm2BlockWarps,
// decltype(warp_gemm)>;
// return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
// }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsReadBlockDescriptor()
{
// Hold all data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
auto shuffled_do_lds_block_desc = MakeShuffledOGradLdsWriteBlockDescriptor<Problem>();
return transform_tensor_descriptor(
shuffled_do_lds_block_desc,
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTRegSliceBlockDescriptor()
{ {
using BlockGemmProblem = using BlockGemm = remove_cvref_t<decltype(GetPTOGradTBlockGemm<Problem>())>;
BlockGemmPipelineProblem<typename Problem::GemmDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::QDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>>;
using WarpGemm = constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
typename Problem::QDataType,
typename Problem::AccDataType, constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}), // constexpr index_t kNPerBlock = 32;
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
using BlockGemmPolicy = constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
BlockGemmARegBSmemCRegV1CustomPolicy<typename Problem::GemmDataType,
typename Problem::QDataType, constexpr auto dot_block_outer_dstr_encoding =
typename Problem::AccDataType, tile_distribution_encoding<sequence<MWarp>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps, tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
WarpGemm>; tuple<sequence<0, 1>>,
return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto dot_block_dstr = make_static_tile_distribution(dot_block_dstr_encode);
return dot_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto MakePTRegSliceBlockDescriptor()
{ {
using BlockGemmProblem = using BlockGemm = remove_cvref_t<decltype(GetPTOGradTBlockGemm<Problem>())>;
BlockGemmPipelineProblem<typename Problem::GemmDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::KDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>>;
using WarpGemm = constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
typename Problem::KDataType,
typename Problem::AccDataType, constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}), constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
true>; constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
using BlockGemmPolicy =
BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::GemmDataType, constexpr auto pt_block_outer_dstr_encoding =
typename Problem::KDataType, tile_distribution_encoding<sequence<NWarp>,
typename Problem::AccDataType, tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps, tuple<sequence<1, 0>>,
WarpGemm>; tuple<sequence<1, 0>>,
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode);
return pt_block_dstr;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPack = GetSmemKPackSGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSGradRegSliceBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK4;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto ds_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto ds_block_dstr = make_static_tile_distribution(ds_block_dstr_encode);
return ds_block_dstr;
}
template <typename Problem, typename PTOutTensor, typename PInTensor>
CK_TILE_DEVICE static constexpr void PTFromGemm0CToGemm1A(PTOutTensor& pt_out,
const PInTensor& p_in)
{
if constexpr(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 16)
{
using BlockGemm = remove_cvref_t<decltype(GetPTOGradTBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
auto pt_warp_tensor =
make_static_distributed_tensor<typename Problem::GemmDataType>(CWarpDstr{});
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
pt_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
pt_out.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
pt_warp_tensor.get_thread_buffer());
});
});
}
else
{
pt_out.get_thread_buffer() = p_in.get_thread_buffer();
}
}
template <typename Problem, typename SGradTOutTensor, typename SGradInTensor>
CK_TILE_DEVICE static constexpr void SGradTFromGemm2CToGemm3A(SGradTOutTensor& dst_out,
const SGradInTensor& ds_in)
{
if constexpr(Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}) == 16)
{
using BlockGemm = remove_cvref_t<decltype(GetSGradTQTBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
auto dst_warp_tensor =
make_static_distributed_tensor<typename Problem::GemmDataType>(CWarpDstr{});
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
dst_out.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
dst_warp_tensor.get_thread_buffer());
});
});
}
else
{
dst_out.get_thread_buffer() = ds_in.get_thread_buffer();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t N1 = GetAlignmentBias<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t M2 = GetTransposedAlignmentBias<Problem>();
constexpr index_t M1 = get_warp_size() / N0;
constexpr index_t M0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsBlockDescriptor()
{
// Hold full block data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetSmemKPackBias<Problem>();
constexpr index_t kKPackT = GetSmemKPackBiasT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kMPerBlock, kKPack, kKPackT>();
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasSTileDistribution()
{
using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile());
return c_block_tensor_type::get_tile_distribution();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeQ()
{
constexpr index_t smem_size_q = sizeof(typename Problem::QDataType) *
MakeQLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_q;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeQT()
{
constexpr index_t smem_size_qt =
sizeof(typename Problem::QDataType) *
MakeShuffledQLdsWriteBlockDescriptor<Problem>().get_element_space_size();
return smem_size_qt;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeK()
{
constexpr index_t smem_size_k =
sizeof(typename Problem::KDataType) *
MakeKLdsWriteBlockDescriptor<Problem>().get_element_space_size();
return smem_size_k;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeKT()
{
constexpr index_t smem_size_kt =
sizeof(typename Problem::KDataType) *
MakeKTLdsReadBlockDescriptor<Problem>().get_element_space_size();
return smem_size_kt;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE()
{
constexpr index_t smem_size_lse =
sizeof(typename Problem::LSEDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
return smem_size_lse;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD()
{
constexpr index_t smem_size_d =
sizeof(typename Problem::DDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
return smem_size_d;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeV()
{
constexpr index_t smem_size_v =
sizeof(typename Problem::VDataType) *
MakeVLdsWriteBlockDescriptor<Problem>().get_element_space_size();
return smem_size_v;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGrad()
{
constexpr index_t smem_size_do =
sizeof(typename Problem::OGradDataType) *
MakeOGradLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_do;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGradT()
{
constexpr index_t smem_size_dot =
sizeof(typename Problem::OGradDataType) *
MakeShuffledOGradLdsWriteBlockDescriptor<Problem>().get_element_space_size();
return smem_size_dot;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeSGrad()
{
constexpr index_t smem_size_ds =
sizeof(typename Problem::GemmDataType) *
MakeSGradLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_ds;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeBias()
{
constexpr index_t smem_size_bias = [&]() {
if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return sizeof(typename Problem::BiasDataType) *
MakeBiasLdsBlockDescriptor<Problem>().get_element_space_size();
else
return 0;
}();
return smem_size_bias;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_q = GetSmemSizeQ<Problem>();
constexpr index_t smem_size_qt = GetSmemSizeQT<Problem>();
constexpr index_t smem_size_lse = GetSmemSizeLSE<Problem>();
constexpr index_t smem_size_k = GetSmemSizeK<Problem>();
constexpr index_t smem_size_kt = GetSmemSizeKT<Problem>();
constexpr index_t smem_size_v = GetSmemSizeV<Problem>();
constexpr index_t smem_size_do = GetSmemSizeOGrad<Problem>();
constexpr index_t smem_size_dot = GetSmemSizeOGradT<Problem>();
constexpr index_t smem_size_d = GetSmemSizeD<Problem>();
constexpr index_t smem_size_ds = GetSmemSizeSGrad<Problem>();
constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
constexpr index_t smem_size_stage0_1 = smem_size_v;
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
smem_size_do + smem_size_lse + smem_size_d +
max(smem_size_bias, smem_size_ds);
return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
}
template <typename Problem_>
struct HotLoopScheduler
{
using Problem = Problem_;
template <index_t GemmStage>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler()
{
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
constexpr index_t VMEM_READ_INST =
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
constexpr index_t MFMA_INST = Gemm0MFMA;
// Evenly distributed to relieve SQ->TA FIFO pressure
constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;
constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;
// To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;
static_for<0, VMEM_READ_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) {
ignore = j;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
});
});
static_for<0, MFMA_Remainder, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
});
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>()
{
// Mem: Q^T LDS load
// Comp: OGrad x V
constexpr index_t LDS_READ_INST = QT_LDS_READ;
constexpr index_t MFMA_INST = Gemm1MFMA;
// To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
});
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>()
{
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad
constexpr index_t LDS_WRITE_INST = Q_LDS_WRITE + QT_LDS_WRITE + OGrad_LDS_WRITE +
OGradT_LDS_WRITE + LSE_LDS_WRITE + D_LDS_WRITE;
constexpr index_t MFMA_INST = Gemm2MFMA;
// To hide instruction issue latency
constexpr index_t LDS_WRITE_PER_MFMA = LDS_WRITE_INST / MFMA_INST;
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write
});
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>()
{
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ;
constexpr index_t MFMA_INST = Gemm3MFMA;
// To hide instruction issue latency
constexpr index_t LDS_WRITE_PER_MFMA =
LDS_WRITE_INST / MFMA_INST >= 1 ? LDS_WRITE_INST / MFMA_INST : 1;
constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA;
constexpr index_t LDS_READ_PER_MFMA =
(MFMA_INST - MFMA_INST_LDS_WRITE) > 0
? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE)
: 1
: 0;
static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
});
static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
});
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>()
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ;
constexpr index_t MFMA_INST = Gemm4MFMA;
// To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA =
LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1;
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
});
}
private:
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
static constexpr index_t WarpGemmM =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static constexpr index_t WarpGemmN =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 16 : 8;
static constexpr index_t Gemm4MWarp =
Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
static constexpr index_t Gemm4NWarp =
Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
// Compute
static constexpr index_t Gemm0MFMA =
kM0 * kN0 * kQKHeaddim /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm1MFMA =
kM0 * kN0 * kVHeaddim /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm2MFMA =
kN0 * kVHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm3MFMA =
kN0 * kQKHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm4MFMA =
kM0 * kQKHeaddim * kN0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
// VMEM
static constexpr index_t Q_VMEM_READ =
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t OGrad_VMEM_READ =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t LSE_VMEM_READ = 1;
static constexpr index_t D_VMEM_READ = 1;
// LDS Read
static constexpr index_t OGradT_LDS_READ =
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
static constexpr index_t QT_LDS_READ =
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
static constexpr index_t SGradT_LDS_READ_P1 =
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t Q_LDS_READ =
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
static constexpr index_t SGradT_LDS_READ_P2 =
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t OGrad_LDS_READ =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// LDS Write
static constexpr index_t Q_LDS_WRITE =
kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t QT_LDS_WRITE =
kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ<Problem>();
static constexpr index_t OGrad_LDS_WRITE =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t OGradT_LDS_WRITE =
kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
static constexpr index_t LSE_LDS_WRITE = 1;
static constexpr index_t D_LDS_WRITE = 1;
static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
};
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -8,9 +8,8 @@ namespace ck_tile { ...@@ -8,9 +8,8 @@ namespace ck_tile {
// This class is used for codegen pattern matching // This class is used for codegen pattern matching
enum class BlockFmhaBwdPipelineEnum enum class BlockFmhaBwdPipelineEnum
{ {
KSKTSVR = 0, KRKTRVR_IGLP = 0,
QSKSVROGradS, KRKTRVR,
KSVR,
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -24,7 +24,9 @@ template <typename QDataType_, ...@@ -24,7 +24,9 @@ template <typename QDataType_,
typename BiasGradDataType_, typename BiasGradDataType_,
typename BlockFmhaShape_, typename BlockFmhaShape_,
bool kIsGroupMode_, bool kIsGroupMode_,
bool kIsDeterministic_,
typename FmhaMask_, typename FmhaMask_,
typename FmhaDropout_,
typename Traits_> typename Traits_>
struct BlockFmhaBwdPipelineProblem struct BlockFmhaBwdPipelineProblem
{ {
...@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem ...@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>; using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>; using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>; using FmhaMask = remove_cvref_t<FmhaMask_>;
using FmhaDropout = remove_cvref_t<FmhaDropout_>;
using Traits = remove_cvref_t<Traits_>; using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits // attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
...@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem ...@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
...@@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem ...@@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
template <typename AccDataType_,
typename QGradDataType_,
index_t kBlockSize_,
index_t kM0_,
index_t kN0_,
index_t kQKHeaddim_,
bool kIsGroupMode_,
bool kIsDeterministic_,
typename Traits_>
struct BlockFmhaBwdConvertQGradPipelineProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using QGradDataType = remove_cvref_t<QGradDataType_>;
using Traits = remove_cvref_t<Traits_>;
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kQKHeaddim = kQKHeaddim_;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile } // namespace ck_tile
...@@ -86,4 +86,14 @@ struct TileFmhaBwdOGradDotOTraits ...@@ -86,4 +86,14 @@ struct TileFmhaBwdOGradDotOTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimQ_ /* paddding for hdim_q */,
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
struct TileFmhaBwdConvertQGradTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile } // namespace ck_tile
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
// M->N Warp
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// check ABC-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"A distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"B distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
template <typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
return c_block_tensor;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename AType_,
typename BType_,
typename CType_,
typename BlockWarps_,
typename WarpGemm_>
struct BlockGemmARegBRegCRegV1CustomPolicy
{
using AType = remove_cvref_t<AType_>;
using BType = remove_cvref_t<BType_>;
using CType = remove_cvref_t<CType_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
static constexpr index_t kKWarps = BlockWarps::at(number<2>{});
using WarpGemm = remove_cvref_t<WarpGemm_>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
return make_tuple(WarpGemm{}, kMWarps, kNWarps);
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment