Commit 74f1516c authored by danyao12's avatar danyao12
Browse files

tmp save

parent 497ccb87
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, k & k^t located in lds.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ true,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, k located in lds.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, q & k & do located in lds.
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ true,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
......@@ -8,9 +8,7 @@ namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockFmhaBwdPipelineEnum
{
KSKTSVR = 0,
QSKSVROGradS,
KSVR,
KRKTRVR = 0,
};
} // namespace ck_tile
......@@ -24,7 +24,9 @@ template <typename QDataType_,
typename BiasGradDataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
bool kIsDeterministic_,
typename FmhaMask_,
typename FmhaDropout_,
typename Traits_>
struct BlockFmhaBwdPipelineProblem
{
......@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using FmhaDropout = remove_cvref_t<FmhaDropout_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
......@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
......@@ -88,4 +91,30 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
template <typename AccDataType_,
typename QGradDataType_,
typename Shape_,
typename Traits_,
bool kIsGroupMode_,
bool kIsDeterministic_>
struct BlockFmhaBwdConvertQGradPipelineProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using QGradDataType = remove_cvref_t<QGradDataType_>;
using Shape = remove_cvref_t<Shape_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = Shape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
static_assert(0 < kBlockSize && kBlockSize % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile
......@@ -28,6 +28,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
......@@ -49,8 +50,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kHasDropout = false; // ignore this flag
static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
......@@ -141,7 +141,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
FmhaDropout dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
......@@ -249,7 +249,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
......@@ -501,10 +501,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
});
});
if constexpr(kHasDropout)
if constexpr(FmhaDropout::IsDropout)
{
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr,
q_origin.at(number<0>{}),
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
block_sync_lds();
......@@ -637,7 +641,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
FmhaDropout dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
......
......@@ -29,6 +29,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
......@@ -54,8 +55,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kHasDropout = false; // ignore this flag
static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
......@@ -153,7 +153,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
FmhaDropout dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
......@@ -301,7 +301,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
......@@ -584,12 +584,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
});
});
if constexpr(kHasDropout)
if constexpr(FmhaDropout::IsDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
q_origin.at(number<0>{}),
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
......@@ -741,7 +742,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
FmhaDropout dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
......
......@@ -21,6 +21,7 @@ template <typename QDataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename FmhaMask_,
typename FmhaDropout_,
typename Traits_>
struct BlockFmhaPipelineProblem
{
......@@ -37,6 +38,7 @@ struct BlockFmhaPipelineProblem
using ODataType = remove_cvref_t<ODataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using FmhaDropout = remove_cvref_t<FmhaDropout_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
......@@ -49,7 +51,6 @@ struct BlockFmhaPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
......@@ -68,6 +69,7 @@ template <typename QDataType,
typename BlockFmhaShape,
bool kIsGroupMode,
typename FmhaMask,
typename FmhaDropout,
typename Traits>
struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
KDataType,
......@@ -83,6 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
BlockFmhaShape,
kIsGroupMode,
FmhaMask,
FmhaDropout,
Traits>
{
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
......
......@@ -29,6 +29,7 @@ struct BlockFmhaPipelineQRKSVS
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
......@@ -51,7 +52,6 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
......@@ -100,8 +100,6 @@ struct BlockFmhaPipelineQRKSVS
static constexpr const char* name = "qr";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
......@@ -141,7 +139,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
DropoutType& dropout) const
FmhaDropout dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
......@@ -486,10 +484,14 @@ struct BlockFmhaPipelineQRKSVS
});
});
if constexpr(kHasDropout)
if constexpr(FmhaDropout::IsDropout)
{
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
smem_ptr,
q_origin.at(number<0>{}),
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
block_sync_lds();
......@@ -620,7 +622,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
DropoutType& dropout) const
FmhaDropout dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
......
......@@ -28,6 +28,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
......@@ -124,7 +125,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
float descale_qk,
float descale_sv,
void* smem_ptr,
BlockDropout& /*dropout*/) const // not supported
FmhaDropout& /*dropout*/) const // not supported
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment