Commit 74f1516c authored by danyao12's avatar danyao12
Browse files

tmp save

parent 497ccb87
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, k & k^t located in lds.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ true,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, k located in lds.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, q & k & do located in lds.
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ true,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
...@@ -8,9 +8,7 @@ namespace ck_tile { ...@@ -8,9 +8,7 @@ namespace ck_tile {
// This class is used for codegen pattern matching // This class is used for codegen pattern matching
enum class BlockFmhaBwdPipelineEnum enum class BlockFmhaBwdPipelineEnum
{ {
KSKTSVR = 0, KRKTRVR = 0,
QSKSVROGradS,
KSVR,
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -24,7 +24,9 @@ template <typename QDataType_, ...@@ -24,7 +24,9 @@ template <typename QDataType_,
typename BiasGradDataType_, typename BiasGradDataType_,
typename BlockFmhaShape_, typename BlockFmhaShape_,
bool kIsGroupMode_, bool kIsGroupMode_,
bool kIsDeterministic_,
typename FmhaMask_, typename FmhaMask_,
typename FmhaDropout_,
typename Traits_> typename Traits_>
struct BlockFmhaBwdPipelineProblem struct BlockFmhaBwdPipelineProblem
{ {
...@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem ...@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>; using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>; using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>; using FmhaMask = remove_cvref_t<FmhaMask_>;
using FmhaDropout = remove_cvref_t<FmhaDropout_>;
using Traits = remove_cvref_t<Traits_>; using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits // attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
...@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem ...@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
...@@ -88,4 +91,30 @@ struct BlockFmhaBwdOGradDotOPipelineProblem ...@@ -88,4 +91,30 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
template <typename AccDataType_,
typename QGradDataType_,
typename Shape_,
typename Traits_,
bool kIsGroupMode_,
bool kIsDeterministic_>
struct BlockFmhaBwdConvertQGradPipelineProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using QGradDataType = remove_cvref_t<QGradDataType_>;
using Shape = remove_cvref_t<Shape_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = Shape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
static_assert(0 < kBlockSize && kBlockSize % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile } // namespace ck_tile
...@@ -28,6 +28,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -28,6 +28,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
using PDataType = remove_cvref_t<typename Problem::PDataType>; using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -49,8 +50,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -49,8 +50,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc) static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kHasDropout = false; // ignore this flag
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
...@@ -141,7 +141,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -141,7 +141,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const FmhaDropout dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -249,7 +249,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -249,7 +249,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>( auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start); randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window = auto v_dram_window =
...@@ -501,10 +501,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -501,10 +501,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}); });
}); });
if constexpr(kHasDropout) if constexpr(FmhaDropout::IsDropout)
{ {
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>( dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); smem_ptr,
q_origin.at(number<0>{}),
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
} }
block_sync_lds(); block_sync_lds();
...@@ -637,7 +641,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -637,7 +641,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const FmhaDropout dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
......
...@@ -29,6 +29,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -29,6 +29,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
using PDataType = remove_cvref_t<typename Problem::PDataType>; using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -54,8 +55,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -54,8 +55,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc) static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kHasDropout = false; // ignore this flag
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
...@@ -153,7 +153,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -153,7 +153,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const FmhaDropout dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -301,7 +301,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -301,7 +301,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>( auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start); randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window = auto v_dram_window =
...@@ -584,12 +584,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -584,12 +584,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
}); });
}); });
if constexpr(kHasDropout) if constexpr(FmhaDropout::IsDropout)
{ {
auto randval_ptr = auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>(); reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>( dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr, randval_ptr,
q_origin.at(number<0>{}),
seqlen_k_start + i_total_loops * kN0, seqlen_k_start + i_total_loops * kN0,
p_compute, p_compute,
randval_dram_window); randval_dram_window);
...@@ -741,7 +742,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -741,7 +742,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const FmhaDropout dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
......
...@@ -21,6 +21,7 @@ template <typename QDataType_, ...@@ -21,6 +21,7 @@ template <typename QDataType_,
typename BlockFmhaShape_, typename BlockFmhaShape_,
bool kIsGroupMode_, bool kIsGroupMode_,
typename FmhaMask_, typename FmhaMask_,
typename FmhaDropout_,
typename Traits_> typename Traits_>
struct BlockFmhaPipelineProblem struct BlockFmhaPipelineProblem
{ {
...@@ -37,6 +38,7 @@ struct BlockFmhaPipelineProblem ...@@ -37,6 +38,7 @@ struct BlockFmhaPipelineProblem
using ODataType = remove_cvref_t<ODataType_>; using ODataType = remove_cvref_t<ODataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>; using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>; using FmhaMask = remove_cvref_t<FmhaMask_>;
using FmhaDropout = remove_cvref_t<FmhaDropout_>;
using Traits = remove_cvref_t<Traits_>; using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
...@@ -49,7 +51,6 @@ struct BlockFmhaPipelineProblem ...@@ -49,7 +51,6 @@ struct BlockFmhaPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
...@@ -68,6 +69,7 @@ template <typename QDataType, ...@@ -68,6 +69,7 @@ template <typename QDataType,
typename BlockFmhaShape, typename BlockFmhaShape,
bool kIsGroupMode, bool kIsGroupMode,
typename FmhaMask, typename FmhaMask,
typename FmhaDropout,
typename Traits> typename Traits>
struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType, struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
KDataType, KDataType,
...@@ -83,6 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType, ...@@ -83,6 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
BlockFmhaShape, BlockFmhaShape,
kIsGroupMode, kIsGroupMode,
FmhaMask, FmhaMask,
FmhaDropout,
Traits> Traits>
{ {
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
......
...@@ -29,6 +29,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -29,6 +29,7 @@ struct BlockFmhaPipelineQRKSVS
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -51,7 +52,6 @@ struct BlockFmhaPipelineQRKSVS ...@@ -51,7 +52,6 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this // ... together with tensor distribution. tensor dist should able to overwrite this
...@@ -100,8 +100,6 @@ struct BlockFmhaPipelineQRKSVS ...@@ -100,8 +100,6 @@ struct BlockFmhaPipelineQRKSVS
static constexpr const char* name = "qr"; static constexpr const char* name = "qr";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
...@@ -141,7 +139,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -141,7 +139,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr,
DropoutType& dropout) const FmhaDropout dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -486,10 +484,14 @@ struct BlockFmhaPipelineQRKSVS ...@@ -486,10 +484,14 @@ struct BlockFmhaPipelineQRKSVS
}); });
}); });
if constexpr(kHasDropout) if constexpr(FmhaDropout::IsDropout)
{ {
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>( dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); smem_ptr,
q_origin.at(number<0>{}),
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
} }
block_sync_lds(); block_sync_lds();
...@@ -620,7 +622,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -620,7 +622,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr,
DropoutType& dropout) const FmhaDropout dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
......
...@@ -28,6 +28,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -28,6 +28,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -124,7 +125,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -124,7 +125,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
float descale_qk, float descale_qk,
float descale_sv, float descale_sv,
void* smem_ptr, void* smem_ptr,
BlockDropout& /*dropout*/) const // not supported FmhaDropout& /*dropout*/) const // not supported
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment