Commit 1862b27f authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Revert "update qsksvs pipeline"

This reverts commit bfc997a7.
parent f42beae8
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -100,7 +99,8 @@ struct BlockFmhaPipelineQSKSVS ...@@ -100,7 +99,8 @@ struct BlockFmhaPipelineQSKSVS
static constexpr const char* name = "qs"; static constexpr const char* name = "qs";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>; // using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
using DropoutType = int32_t; // unused
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
...@@ -267,8 +267,7 @@ struct BlockFmhaPipelineQSKSVS ...@@ -267,8 +267,7 @@ struct BlockFmhaPipelineQSKSVS
bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(), bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
// Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto v_dram_window = auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
...@@ -621,46 +620,10 @@ struct BlockFmhaPipelineQSKSVS ...@@ -621,46 +620,10 @@ struct BlockFmhaPipelineQSKSVS
return o_acc; return o_acc;
} }
// template <typename QDramBlockWindowTmp,
// typename KDramBlockWindowTmp,
// typename VDramBlockWindowTmp,
// typename BiasDramBlockWindowTmp,
// typename LSEDramBlockWindowTmp,
// typename PositionEncoding>
// CK_TILE_HOST_DEVICE auto
// operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
// const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
// const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
// const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
// LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
// FmhaMask mask,
// PositionEncoding position_encoding,
// float scale_s,
// void* smem_ptr) const
// {
// return operator()(q_dram_block_window_tmp,
// identity{},
// k_dram_block_window_tmp,
// identity{},
// v_dram_block_window_tmp,
// identity{},
// bias_dram_block_window_tmp,
// identity{},
// lse_dram_block_window_tmp,
// identity{},
// identity{},
// identity{},
// identity{},
// mask,
// position_encoding,
// scale_s,
// smem_ptr);
// }
template <typename QDramBlockWindowTmp, template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
...@@ -668,13 +631,11 @@ struct BlockFmhaPipelineQSKSVS ...@@ -668,13 +631,11 @@ struct BlockFmhaPipelineQSKSVS
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr) const
DropoutType& dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
...@@ -684,7 +645,6 @@ struct BlockFmhaPipelineQSKSVS ...@@ -684,7 +645,6 @@ struct BlockFmhaPipelineQSKSVS
identity{}, identity{},
bias_dram_block_window_tmp, bias_dram_block_window_tmp,
identity{}, identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp, lse_dram_block_window_tmp,
identity{}, identity{},
identity{}, identity{},
...@@ -693,8 +653,7 @@ struct BlockFmhaPipelineQSKSVS ...@@ -693,8 +653,7 @@ struct BlockFmhaPipelineQSKSVS
mask, mask,
position_encoding, position_encoding,
scale_s, scale_s,
smem_ptr, smem_ptr);
dropout);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment