Commit 74f1516c authored by danyao12

tmp save

parent 497ccb87
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// In this pipeline, v is kept in regs while k & k^t are kept in lds.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ true,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// In this pipeline, v is kept in regs while k is kept in lds.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// In this pipeline, v is kept in regs while q, k & do are kept in lds.
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ true,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
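
The three aliases above differ only in which operands are pinned for the whole main loop. The `*LoadOnce_` flags are compile-time properties of the policy type, so pipelines can branch on them with `if constexpr`. A minimal self-contained sketch of that pattern, with placeholder names (the real member names of `BlockFmhaBwdPipelineDefaultPolicy` are not shown in this diff):

// Sketch only -- PipelinePolicySketch stands in for BlockFmhaBwdPipelineDefaultPolicy.
#include <cstdio>

template <bool KLoadOnce_, bool VLoadOnce_>
struct PipelinePolicySketch
{
    static constexpr bool kKLoadOnce = KLoadOnce_; // true: stage k in lds once and reuse it
    static constexpr bool kVLoadOnce = VLoadOnce_; // true: keep v resident in registers
};

// Mirrors the "KS...VR" naming above: K in shared memory (lds), V in registers.
using KSVRSketch = PipelinePolicySketch<true, true>;

int main()
{
    if constexpr(KSVRSketch::kKLoadOnce)
        std::puts("k is loaded into lds once and reused across the main loop");
    if constexpr(KSVRSketch::kVLoadOnce)
        std::puts("v stays in registers for the whole loop");
    return 0;
}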
@@ -8,9 +8,7 @@ namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockFmhaBwdPipelineEnum
{
- KSKTSVR = 0,
- QSKSVROGradS,
- KSVR,
+ KRKTRVR = 0,
};
} // namespace ck_tile
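
The comment says this enum exists for codegen pattern matching; a hedged sketch of what such dispatch can look like (the pipeline type here is a placeholder, not a type from this commit):

// Self-contained sketch of enum-to-pipeline dispatch; names are illustrative.
enum class BlockFmhaBwdPipelineEnumSketch { KRKTRVR = 0 };

struct KrKtrVrPipelineSketch {}; // stand-in for the real pipeline implementation

template <BlockFmhaBwdPipelineEnumSketch> struct pipeline_for; // primary: intentionally undefined
template <> struct pipeline_for<BlockFmhaBwdPipelineEnumSketch::KRKTRVR>
{
    using type = KrKtrVrPipelineSketch;
};

using SelectedPipeline = pipeline_for<BlockFmhaBwdPipelineEnumSketch::KRKTRVR>::type;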
@@ -24,7 +24,9 @@ template <typename QDataType_,
typename BiasGradDataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
+ bool kIsDeterministic_,
typename FmhaMask_,
+ typename FmhaDropout_,
typename Traits_>
struct BlockFmhaBwdPipelineProblem
{
@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
+ using FmhaDropout = remove_cvref_t<FmhaDropout_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
+ static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
- static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
@@ -88,4 +91,30 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
template <typename AccDataType_,
typename QGradDataType_,
typename Shape_,
typename Traits_,
bool kIsGroupMode_,
bool kIsDeterministic_>
struct BlockFmhaBwdConvertQGradPipelineProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using QGradDataType = remove_cvref_t<QGradDataType_>;
using Shape = remove_cvref_t<Shape_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = Shape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
static_assert(0 < kBlockSize && kBlockSize % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile
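
A hedged usage sketch for the new problem descriptor (assumes the ck_tile headers from this diff are included). The shape and traits come from `TileFmhaBwdConvertQGradShape` and `TileFmhaBwdConvertQGradTraits`, which this commit adds further down; the data types and tile sizes are illustrative, not prescribed by the diff:

// Illustrative wiring of BlockFmhaBwdConvertQGradPipelineProblem.
using ConvertQGradShapeSketch =
    ck_tile::TileFmhaBwdConvertQGradShape<ck_tile::sequence<64, 64, 128>, // kM0, kN0, kQKHeaddim
                                          ck_tile::sequence<4, 1, 1>,     // warps per block
                                          ck_tile::sequence<32, 32, 16>>; // per-warp tile (placeholder)
using ConvertQGradTraitsSketch = ck_tile::TileFmhaBwdConvertQGradTraits<true, true>;

using ConvertQGradProblemSketch =
    ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<float,           // AccDataType (illustrative)
                                                     ck_tile::fp16_t, // QGradDataType (illustrative)
                                                     ConvertQGradShapeSketch,
                                                     ConvertQGradTraitsSketch,
                                                     /* kIsGroupMode     = */ false,
                                                     /* kIsDeterministic = */ false>;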
@@ -28,6 +28,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
+ using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
@@ -50,7 +51,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc)
- static constexpr bool kHasDropout = false; // ignore this flag
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create the tensor view (and decide the buffer_load vector length)
@@ -141,7 +141,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
- BlockDropout& dropout) const
+ FmhaDropout dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -249,7 +249,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
- auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
+ auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
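
The added `template` keyword above is required, not cosmetic: `dropout` now has a dependent type (`FmhaDropout` is a template parameter of the pipeline), so without the disambiguator the `<` after `MakeRandvalDramWindow` would parse as less-than. A minimal self-contained reproduction:

// Sketch of the dependent-name rule; DropoutSketch is a placeholder type.
struct DropoutSketch
{
    template <typename Gemm>
    int MakeWindow(int origin) const { return origin; }
};

template <typename D>
int use_dropout(const D& dropout)
{
    // return dropout.MakeWindow<int>(0);       // ill-formed here: '<' parses as less-than
    return dropout.template MakeWindow<int>(0); // OK: 'template' marks a member-template call
}

int main() { return use_dropout(DropoutSketch{}); }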
@@ -501,10 +501,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
});
});
- if constexpr(kHasDropout)
+ if constexpr(FmhaDropout::IsDropout)
{
- dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
- smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
+ dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
+ smem_ptr,
+ q_origin.at(number<0>{}),
+ seqlen_k_start + i_total_loops * kN0,
+ p_compute,
+ randval_dram_window);
}
block_sync_lds();
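
This hunk replaces the pipeline-wide `kHasDropout` boolean with a property of the dropout type itself, and additionally passes the block's q origin into `Run` (presumably to anchor the random-value layout to the tile's row; the diff does not show `Run`'s definition). A self-contained sketch of the `IsDropout` dispatch, with illustrative names:

// Sketch: the dropout type carries its own compile-time flag, so the branch
// below disappears entirely when a null dropout type is passed in.
struct NullBlockDropoutSketch { static constexpr bool IsDropout = false; };
struct BlockDropoutSketch     { static constexpr bool IsDropout = true;  };

template <typename FmhaDropout>
void pipeline_step(FmhaDropout dropout)
{
    if constexpr(FmhaDropout::IsDropout)
    {
        (void)dropout; // the real dropout.Run(...) would only be instantiated here
    }
}

int main()
{
    pipeline_step(NullBlockDropoutSketch{}); // branch compiled out
    pipeline_step(BlockDropoutSketch{});     // branch instantiated
    return 0;
}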
@@ -637,7 +641,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
- BlockDropout& dropout) const
+ FmhaDropout dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
......
@@ -29,6 +29,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
+ using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
@@ -55,7 +56,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
static constexpr bool kPadHeadDimV = true; // support multiples of the vector length (like 8x)
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc)
- static constexpr bool kHasDropout = false; // ignore this flag
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create the tensor view (and decide the buffer_load vector length)
@@ -153,7 +153,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
- BlockDropout& dropout) const
+ FmhaDropout dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -301,7 +301,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
- auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
+ auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
@@ -584,12 +584,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
});
});
- if constexpr(kHasDropout)
+ if constexpr(FmhaDropout::IsDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
- dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
+ dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
+ q_origin.at(number<0>{}),
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
@@ -741,7 +742,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
- BlockDropout& dropout) const
+ FmhaDropout dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
......
@@ -21,6 +21,7 @@ template <typename QDataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename FmhaMask_,
+ typename FmhaDropout_,
typename Traits_>
struct BlockFmhaPipelineProblem
{
@@ -37,6 +38,7 @@ struct BlockFmhaPipelineProblem
using ODataType = remove_cvref_t<ODataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
+ using FmhaDropout = remove_cvref_t<FmhaDropout_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
@@ -49,7 +51,6 @@ struct BlockFmhaPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
- static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
@@ -68,6 +69,7 @@ template <typename QDataType,
typename BlockFmhaShape,
bool kIsGroupMode,
typename FmhaMask,
+ typename FmhaDropout,
typename Traits>
struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
KDataType,
@@ -83,6 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
BlockFmhaShape,
kIsGroupMode,
FmhaMask,
+ FmhaDropout,
Traits>
{
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
......
@@ -29,6 +29,7 @@ struct BlockFmhaPipelineQRKSVS
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
+ using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
@@ -51,7 +52,6 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
- static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create the tensor view (and decide the buffer_load vector length)
// ... together with the tensor distribution; the tensor dist should be able to overwrite this
@@ -100,8 +100,6 @@ struct BlockFmhaPipelineQRKSVS
static constexpr const char* name = "qr";
- using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
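
With the `DropoutType` alias deleted here, the pipeline no longer selects its own dropout implementation; the caller is expected to resolve `Problem::FmhaDropout` once and pass the object in. A hedged fragment of what that selection presumably looks like at the kernel/codegen level (assuming the real `BlockDropout`/`NullBlockDropout` types are in scope):

#include <type_traits>

// Assumed caller-side selection, mirroring the conditional_t that used to
// live inside each pipeline.
template <bool kHasDropout>
using FmhaDropoutSelected =
    std::conditional_t<kHasDropout, ck_tile::BlockDropout, ck_tile::NullBlockDropout>;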
@@ -141,7 +139,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
- DropoutType& dropout) const
+ FmhaDropout dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -486,10 +484,14 @@ struct BlockFmhaPipelineQRKSVS
});
});
- if constexpr(kHasDropout)
+ if constexpr(FmhaDropout::IsDropout)
{
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
- smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
+ smem_ptr,
+ q_origin.at(number<0>{}),
+ seqlen_k_start + i_total_loops * kN0,
+ p_compute,
+ randval_dram_window);
}
block_sync_lds();
@@ -620,7 +622,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
- DropoutType& dropout) const
+ FmhaDropout dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
......
@@ -30,6 +30,7 @@ struct BlockFmhaPipelineQRKSVSAsync
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
+ using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
@@ -56,7 +57,6 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr bool kPadHeadDimV = true; // support multiples of the vector length (like 8x)
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
- static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create the tensor view (and decide the buffer_load vector length)
// ... together with the tensor distribution; the tensor dist should be able to overwrite this
@@ -112,8 +112,6 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr const char* name = "qr_async";
- using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
@@ -153,7 +151,7 @@ struct BlockFmhaPipelineQRKSVSAsync
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
- DropoutType& dropout) const
+ FmhaDropout dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -569,12 +567,13 @@ struct BlockFmhaPipelineQRKSVSAsync
});
});
- if constexpr(kHasDropout)
+ if constexpr(FmhaDropout::IsDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
+ q_origin.at(number<0>{}),
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
@@ -730,7 +729,7 @@ struct BlockFmhaPipelineQRKSVSAsync
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
- DropoutType& dropout) const
+ FmhaDropout dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
......
@@ -28,6 +28,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
+ using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
@@ -124,7 +125,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
float descale_qk,
float descale_sv,
void* smem_ptr,
- BlockDropout& /*dropout*/) const // not supported
+ FmhaDropout& /*dropout*/) const // not supported
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
......
@@ -92,4 +92,20 @@ struct TileFmhaBwdShape
// that need load V at once
};
template <typename BlockTile_, // sequence<...
typename BlockWarps_,
typename WarpTile_>
struct TileFmhaBwdConvertQGradShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
using WarpTile = remove_cvref_t<WarpTile_>;
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kQKHeaddim = BlockTile::at(number<2>{}); // Q & K headdim
};
} // namespace ck_tile
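
A hedged instantiation of the new shape, just to show how the members derive from the three sequences (assumes the ck_tile headers are included); the tile sizes are illustrative:

// Illustrative: a 64x64 tile over head dim 128, split across 4 warps.
using ConvertQGradShapeExample =
    ck_tile::TileFmhaBwdConvertQGradShape<ck_tile::sequence<64, 64, 128>, // kM0, kN0, kQKHeaddim
                                          ck_tile::sequence<4, 1, 1>,     // BlockWarps
                                          ck_tile::sequence<32, 32, 16>>; // WarpTile (placeholder)
static_assert(ConvertQGradShapeExample::NumWarps == 4); // 4 * 1 * 1
static_assert(ConvertQGradShapeExample::kQKHeaddim == 128);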
@@ -15,7 +15,6 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kStoreLSE_,
- bool kHasDropout_,
bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaTraits
@@ -27,7 +26,6 @@ struct TileFmhaTraits
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kStoreLSE = kStoreLSE_;
- static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
@@ -39,7 +37,6 @@ template <bool kPadSeqLenQ /* padding for seqlen_q */,
BlockAttentionBiasEnum BiasEnum,
bool kHasBiasGrad,
bool kStoreLSE,
- bool kHasDropout,
bool kDoFp8StaticQuant,
bool kHasUnevenSplits_ = true,
index_t kBlockPerCu = -1 /* overwrite occupancy if not -1 */>
@@ -50,7 +47,6 @@ struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ,
BiasEnum,
kHasBiasGrad,
kStoreLSE,
- kHasDropout,
kDoFp8StaticQuant,
kBlockPerCu>
{
@@ -86,4 +82,14 @@ struct TileFmhaBwdOGradDotOTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimQ_ /* paddding for hdim_q */,
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
struct TileFmhaBwdConvertQGradTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile
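
And a hedged instantiation of the new traits; note that `kBlockPerCu_` defaults to 2 here (an occupancy hint) rather than the `-1` overwrite sentinel used by `TileFmhaTraits`:

// Illustrative: pad seqlen_q, assume hdim_q needs no padding, keep the default hint.
using ConvertQGradTraitsExample =
    ck_tile::TileFmhaBwdConvertQGradTraits</* kPadSeqLenQ  = */ true,
                                           /* kPadHeadDimQ = */ false>;
static_assert(ConvertQGradTraitsExample::kBlockPerCu == 2);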
@@ -5,6 +5,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
......