"docker/vscode:/vscode.git/clone" did not exist on "dde5cf5d02188c7f493caef32d4ed59fc1805a2c"
Commit 1862b27f authored by Po Yen Chen, committed by GitHub
Browse files

Revert "update qsksvs pipeline"

This reverts commit bfc997a7.
parent f42beae8
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -100,7 +99,8 @@ struct BlockFmhaPipelineQSKSVS ...@@ -100,7 +99,8 @@ struct BlockFmhaPipelineQSKSVS
static constexpr const char* name = "qs"; static constexpr const char* name = "qs";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>; // using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
using DropoutType = int32_t; // unused
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
...@@ -267,8 +267,7 @@ struct BlockFmhaPipelineQSKSVS ...@@ -267,8 +267,7 @@ struct BlockFmhaPipelineQSKSVS
bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(), bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
// Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto v_dram_window = auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
...@@ -621,46 +620,10 @@ struct BlockFmhaPipelineQSKSVS ...@@ -621,46 +620,10 @@ struct BlockFmhaPipelineQSKSVS
return o_acc; return o_acc;
} }
// template <typename QDramBlockWindowTmp,
// typename KDramBlockWindowTmp,
// typename VDramBlockWindowTmp,
// typename BiasDramBlockWindowTmp,
// typename LSEDramBlockWindowTmp,
// typename PositionEncoding>
// CK_TILE_HOST_DEVICE auto
// operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
// const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
// const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
// const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
// LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
// FmhaMask mask,
// PositionEncoding position_encoding,
// float scale_s,
// void* smem_ptr) const
// {
// return operator()(q_dram_block_window_tmp,
// identity{},
// k_dram_block_window_tmp,
// identity{},
// v_dram_block_window_tmp,
// identity{},
// bias_dram_block_window_tmp,
// identity{},
// lse_dram_block_window_tmp,
// identity{},
// identity{},
// identity{},
// identity{},
// mask,
// position_encoding,
// scale_s,
// smem_ptr);
// }
template <typename QDramBlockWindowTmp, template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
...@@ -668,13 +631,11 @@ struct BlockFmhaPipelineQSKSVS ...@@ -668,13 +631,11 @@ struct BlockFmhaPipelineQSKSVS
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr) const
DropoutType& dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
...@@ -684,7 +645,6 @@ struct BlockFmhaPipelineQSKSVS ...@@ -684,7 +645,6 @@ struct BlockFmhaPipelineQSKSVS
identity{}, identity{},
bias_dram_block_window_tmp, bias_dram_block_window_tmp,
identity{}, identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp, lse_dram_block_window_tmp,
identity{}, identity{},
identity{}, identity{},
...@@ -693,8 +653,7 @@ struct BlockFmhaPipelineQSKSVS ...@@ -693,8 +653,7 @@ struct BlockFmhaPipelineQSKSVS
mask, mask,
position_encoding, position_encoding,
scale_s, scale_s,
smem_ptr, smem_ptr);
dropout);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment