Unverified Commit 2cab8d39 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

CK Tile FA Training kernels (#1286)



* FA fwd dropout

* FA bwd

* epilogue reuse

* CMakeLists update

* [CK_TILE] support alibi (#1269)

* add alibi support

* fix code

* update code based on comment

* Support more hdim

* fix fp8 bias

* support seqlen_k=0 case

* remove unused printf

* fix format

---------
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>

* now fwd/bwd can build

* bwd alibi

* add bwd validation stream_config

* update generated filenames

* update bwd kernel launch

* CK_TILE_HOST_DEVICE in philox

* Transpose -> transpose

* format

* format

* format

* Generate the instance for FA required

* format

* fix error in WarpGemm

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: default avatarcarlushuang <carlus.huang@amd.com>
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>
Co-authored-by: default avatarPo Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: default avatarJing Zhang <jizhan@amd.com>
parent 76827d82
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockFmhaBwdPipelineEnum
{
KSKTSVR = 0,
QSKSVROGradS,
KSVR,
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
typename GemmDataType_,
typename LSEDataType_,
typename AccDataType_,
typename DDataType_,
typename BiasDataType_,
typename RandValOutputDataType_,
typename ODataType_,
typename OGradDataType_,
typename QGradDataType_,
typename KGradDataType_,
typename VGradDataType_,
typename BiasGradDataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename FmhaMask_,
typename Traits_>
struct BlockFmhaBwdPipelineProblem
{
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using GemmDataType = remove_cvref_t<GemmDataType_>;
using LSEDataType = remove_cvref_t<LSEDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using DDataType = remove_cvref_t<DDataType_>;
using BiasDataType = remove_cvref_t<BiasDataType_>;
using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using OGradDataType = remove_cvref_t<OGradDataType_>;
using QGradDataType = remove_cvref_t<QGradDataType_>;
using KGradDataType = remove_cvref_t<KGradDataType_>;
using VGradDataType = remove_cvref_t<VGradDataType_>;
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
template <typename ODataType_,
typename OGradDataType_,
typename DDataType_,
index_t kBlockSize_,
index_t kVHeaddim_,
bool kIsGroupMode_,
typename Traits_>
struct BlockFmhaBwdOGradDotOPipelineProblem
{
using ODataType = remove_cvref_t<ODataType_>;
using OGradDataType = remove_cvref_t<OGradDataType_>;
using DDataType = remove_cvref_t<DDataType_>;
using Traits = remove_cvref_t<Traits_>;
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kVHeaddim = kVHeaddim_;
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile
...@@ -13,6 +13,7 @@ template <typename QDataType_, ...@@ -13,6 +13,7 @@ template <typename QDataType_,
typename SaccDataType_, typename SaccDataType_,
typename SMPLComputeDataType_, typename SMPLComputeDataType_,
typename BiasDataType_, typename BiasDataType_,
typename RandValOutputDataType_,
typename LSEDataType_, typename LSEDataType_,
typename PDataType_, typename PDataType_,
typename OaccDataType_, typename OaccDataType_,
...@@ -23,19 +24,20 @@ template <typename QDataType_, ...@@ -23,19 +24,20 @@ template <typename QDataType_,
typename Traits_> typename Traits_>
struct BlockFmhaPipelineProblem struct BlockFmhaPipelineProblem
{ {
using QDataType = remove_cvref_t<QDataType_>; using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>; using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>; using VDataType = remove_cvref_t<VDataType_>;
using SaccDataType = remove_cvref_t<SaccDataType_>; using SaccDataType = remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>; using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
using BiasDataType = remove_cvref_t<BiasDataType_>; using BiasDataType = remove_cvref_t<BiasDataType_>;
using LSEDataType = remove_cvref_t<LSEDataType_>; using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
using PDataType = remove_cvref_t<PDataType_>; using LSEDataType = remove_cvref_t<LSEDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>; using PDataType = remove_cvref_t<PDataType_>;
using ODataType = remove_cvref_t<ODataType_>; using OaccDataType = remove_cvref_t<OaccDataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>; using ODataType = remove_cvref_t<ODataType_>;
using FmhaMask = remove_cvref_t<FmhaMask_>; using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using Traits = remove_cvref_t<Traits_>; using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
...@@ -47,6 +49,7 @@ struct BlockFmhaPipelineProblem ...@@ -47,6 +49,7 @@ struct BlockFmhaPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -14,19 +15,20 @@ namespace ck_tile { ...@@ -14,19 +15,20 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy> template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
struct BlockFmhaPipelineQRKSVS struct BlockFmhaPipelineQRKSVS
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>; using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>; using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>; using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>; using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using PDataType = remove_cvref_t<typename Problem::PDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -49,6 +51,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -49,6 +51,7 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this // ... together with tensor distribution. tensor dist should able to overwrite this
...@@ -106,6 +109,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -106,6 +109,7 @@ struct BlockFmhaPipelineQRKSVS
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename QElementFunction, typename QElementFunction,
typename KElementFunction, typename KElementFunction,
...@@ -125,6 +129,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -125,6 +129,7 @@ struct BlockFmhaPipelineQRKSVS
const VElementFunction& v_element_func, const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func, const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func, const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func, const SAccElementFunction& s_acc_element_func,
...@@ -133,7 +138,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -133,7 +138,8 @@ struct BlockFmhaPipelineQRKSVS
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr,
BlockDropout& dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -240,6 +246,9 @@ struct BlockFmhaPipelineQRKSVS ...@@ -240,6 +246,9 @@ struct BlockFmhaPipelineQRKSVS
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window = auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(), v_dram_block_window_tmp.get_window_lengths(),
...@@ -475,6 +484,12 @@ struct BlockFmhaPipelineQRKSVS ...@@ -475,6 +484,12 @@ struct BlockFmhaPipelineQRKSVS
}); });
}); });
if constexpr(kHasDropout)
{
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
}
block_sync_lds(); block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
...@@ -589,6 +604,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -589,6 +604,7 @@ struct BlockFmhaPipelineQRKSVS
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
...@@ -596,11 +612,13 @@ struct BlockFmhaPipelineQRKSVS ...@@ -596,11 +612,13 @@ struct BlockFmhaPipelineQRKSVS
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr,
BlockDropout& dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
...@@ -610,6 +628,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -610,6 +628,7 @@ struct BlockFmhaPipelineQRKSVS
identity{}, identity{},
bias_dram_block_window_tmp, bias_dram_block_window_tmp,
identity{}, identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp, lse_dram_block_window_tmp,
identity{}, identity{},
identity{}, identity{},
...@@ -618,7 +637,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -618,7 +637,8 @@ struct BlockFmhaPipelineQRKSVS
mask, mask,
position_encoding, position_encoding,
scale_s, scale_s,
smem_ptr); smem_ptr,
dropout);
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -15,19 +16,20 @@ namespace ck_tile { ...@@ -15,19 +16,20 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy> template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
struct BlockFmhaPipelineQRKSVSAsync struct BlockFmhaPipelineQRKSVSAsync
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>; using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>; using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>; using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>; using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using PDataType = remove_cvref_t<typename Problem::PDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -54,6 +56,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -54,6 +56,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this // ... together with tensor distribution. tensor dist should able to overwrite this
...@@ -118,6 +121,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -118,6 +121,7 @@ struct BlockFmhaPipelineQRKSVSAsync
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename QElementFunction, typename QElementFunction,
typename KElementFunction, typename KElementFunction,
...@@ -137,6 +141,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -137,6 +141,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const VElementFunction& v_element_func, const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func, const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func, const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func, const SAccElementFunction& s_acc_element_func,
...@@ -145,7 +150,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -145,7 +150,8 @@ struct BlockFmhaPipelineQRKSVSAsync
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr,
BlockDropout& dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -292,6 +298,9 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -292,6 +298,9 @@ struct BlockFmhaPipelineQRKSVSAsync
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window = auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(), v_dram_block_window_tmp.get_window_lengths(),
...@@ -558,6 +567,17 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -558,6 +567,17 @@ struct BlockFmhaPipelineQRKSVSAsync
}); });
}); });
if constexpr(kHasDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
const auto p = const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)); cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
...@@ -688,6 +708,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -688,6 +708,7 @@ struct BlockFmhaPipelineQRKSVSAsync
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
...@@ -695,11 +716,13 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -695,11 +716,13 @@ struct BlockFmhaPipelineQRKSVSAsync
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr,
BlockDropout& dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
...@@ -709,6 +732,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -709,6 +732,7 @@ struct BlockFmhaPipelineQRKSVSAsync
identity{}, identity{},
bias_dram_block_window_tmp, bias_dram_block_window_tmp,
identity{}, identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp, lse_dram_block_window_tmp,
identity{}, identity{},
identity{}, identity{},
...@@ -717,7 +741,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -717,7 +741,8 @@ struct BlockFmhaPipelineQRKSVSAsync
mask, mask,
position_encoding, position_encoding,
scale_s, scale_s,
smem_ptr); smem_ptr,
dropout);
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,19 +14,20 @@ namespace ck_tile { ...@@ -14,19 +14,20 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy> template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>; using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>; using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>; using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>; using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using PDataType = remove_cvref_t<typename Problem::PDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -49,6 +50,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -49,6 +50,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this // ... together with tensor distribution. tensor dist should able to overwrite this
...@@ -106,20 +108,23 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -106,20 +108,23 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported RandValDramBlockWindowTmp& /*randval_dram_block_window_tmp*/, // not supported
LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported
FmhaMask mask, FmhaMask mask,
PositionEncoding /*position_encoding*/, PositionEncoding /*position_encoding*/,
float scale_s, float scale_s,
float descale_qk, float descale_qk,
float descale_sv, float descale_sv,
void* smem_ptr) const void* smem_ptr,
BlockDropout& /*dropout*/) const // not supported
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -13,19 +13,20 @@ namespace ck_tile { ...@@ -13,19 +13,20 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy> template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
struct BlockFmhaPipelineQSKSVS struct BlockFmhaPipelineQSKSVS
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>; using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>; using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>; using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>; using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using PDataType = remove_cvref_t<typename Problem::PDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -89,13 +89,13 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -89,13 +89,13 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
std::is_same_v<typename Problem::KDataType, half_t> && std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>) std::is_same_v<typename Problem::SaccDataType, float>)
{ {
return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
} }
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> && else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> && std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>) std::is_same_v<typename Problem::SaccDataType, float>)
{ {
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
} }
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> && else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> && std::is_same_v<typename Problem::KDataType, fp8_t> &&
...@@ -212,13 +212,13 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -212,13 +212,13 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
std::is_same_v<typename Problem::KDataType, half_t> && std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>) std::is_same_v<typename Problem::SaccDataType, float>)
{ {
return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
} }
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> && else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> && std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>) std::is_same_v<typename Problem::SaccDataType, float>)
{ {
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
} }
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> && else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> && std::is_same_v<typename Problem::KDataType, fp8_t> &&
...@@ -691,7 +691,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -691,7 +691,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{ {
// TODO: assume Q is in register // TODO: assume Q is in register
// TODO: assume K/V has same data type // TODO: assume K/V has same data type
...@@ -702,6 +702,40 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -702,6 +702,40 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
single_smem_size * max(NumPrefetchK, NumPrefetchV); single_smem_size * max(NumPrefetchK, NumPrefetchV);
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
if constexpr(AsyncCopyK)
{
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
}
else
{
return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>());
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout()
{
if constexpr(Problem::kHasDropout)
{
constexpr auto gemm_0 = QXPolicy::template GetQKBlockGemm<Problem>();
constexpr auto config =
decltype(gemm_0)::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = WG::kN;
return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t);
}
else
{
return 0;
}
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{ {
......
...@@ -43,4 +43,53 @@ struct TileFmhaShape ...@@ -43,4 +43,53 @@ struct TileFmhaShape
ck_tile::tensor_layout::gemm::ColumnMajor>; ck_tile::tensor_layout::gemm::ColumnMajor>;
}; };
template <typename BlockTile_, // sequence<...
typename Gemm0BlockWarps_,
typename Gemm0WarpTile_,
typename Gemm1BlockWarps_,
typename Gemm1WarpTile_,
typename Gemm2BlockWarps_,
typename Gemm2WarpTile_,
typename Gemm3BlockWarps_,
typename Gemm3WarpTile_,
typename Gemm4BlockWarps_,
typename Gemm4WarpTile_>
struct TileFmhaBwdShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
using Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>;
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
using Gemm2BlockWarps = remove_cvref_t<Gemm2BlockWarps_>;
using Gemm2WarpTile = remove_cvref_t<Gemm2WarpTile_>;
using Gemm3BlockWarps = remove_cvref_t<Gemm3BlockWarps_>;
using Gemm3WarpTile = remove_cvref_t<Gemm3WarpTile_>;
using Gemm4BlockWarps = remove_cvref_t<Gemm4BlockWarps_>;
using Gemm4WarpTile = remove_cvref_t<Gemm4WarpTile_>;
static constexpr index_t NumWarps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) &&
NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies{}, number<1>{}));
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kK0 =
BlockTile::at(number<2>{}); // tile size along gemm0(Q@K^T) unroll
static constexpr index_t kK1 =
BlockTile::at(number<3>{}); // tile size along gemm1(P^T@dO) unroll
static constexpr index_t kK2 =
BlockTile::at(number<4>{}); // tile size along gemm2(dO@V^T) unroll
static constexpr index_t kK3 =
BlockTile::at(number<5>{}); // tile size along gemm3(dS^T@Q) unroll
static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll
static constexpr index_t kQKHeaddim =
BlockTile::at(number<7>{}); // Q & K headdim, used for pipeline that need load Q/Q^T or
// K/K^T at once
static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline
// that need load V at once
};
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -13,7 +13,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */, ...@@ -13,7 +13,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimQ_ /* paddding for hdim_q */, bool kPadHeadDimQ_ /* paddding for hdim_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */, bool kPadHeadDimV_ /* paddding for hdim_v */,
BlockAttentionBiasEnum BiasEnum_, BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kStoreLSE_, bool kStoreLSE_,
bool kHasDropout_,
bool kDoFp8StaticQuant_, bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */> index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaTraits struct TileFmhaTraits
...@@ -23,9 +25,21 @@ struct TileFmhaTraits ...@@ -23,9 +25,21 @@ struct TileFmhaTraits
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_; static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr auto BiasEnum = BiasEnum_; static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kStoreLSE = kStoreLSE_; static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */,
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
struct TileFmhaBwdOGradDotOTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile } // namespace ck_tile
...@@ -3,20 +3,21 @@ ...@@ -3,20 +3,21 @@
#pragma once #pragma once
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Problem Description for BlockGemmARegBGmemCReg
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
struct BlockGemmARegBGmemCRegProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -28,7 +28,7 @@ struct BlockGemmARegBGmemCRegV1 ...@@ -28,7 +28,7 @@ struct BlockGemmARegBGmemCRegV1
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation // use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1< using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1<
BlockGemmARegBSmemCRegProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>, BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
BlockGemmARegBSmemCRegV1DefaultPolicy>; BlockGemmARegBSmemCRegV1DefaultPolicy>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -35,13 +35,16 @@ struct BlockGemmARegBSmemCRegV1 ...@@ -35,13 +35,16 @@ struct BlockGemmARegBSmemCRegV1
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>, std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!"); "wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; // constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; // constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; // constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK, // KPerBlock == BlockGemmShape::kK,
"wrong!"); // "wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>(); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
...@@ -181,23 +184,10 @@ struct BlockGemmARegBSmemCRegV1 ...@@ -181,23 +184,10 @@ struct BlockGemmARegBSmemCRegV1
}); });
} }
// C = A * B CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{ {
static_assert( constexpr index_t MPerBlock = BlockGemmShape::kM;
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> && constexpr index_t NPerBlock = BlockGemmShape::kN;
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>(); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
...@@ -208,20 +198,7 @@ struct BlockGemmARegBSmemCRegV1 ...@@ -208,20 +198,7 @@ struct BlockGemmARegBSmemCRegV1
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id() % NWarp;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>, sequence<>,
...@@ -231,108 +208,20 @@ struct BlockGemmARegBSmemCRegV1 ...@@ -231,108 +208,20 @@ struct BlockGemmARegBSmemCRegV1
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// constrcut from A-block-tensor from A-Block-tensor-tmp // C = A * B
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent template <typename ABlockTensorTmp, typename BBlockWindowTmp>
// distribution CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
auto a_block_tensor = const BBlockWindowTmp& b_block_window_tmp) const
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr); {
auto c_block_tensor = MakeCBlockTile();
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// Construct C-Block-HostTensor
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
return c_block_tensor; return c_block_tensor;
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment