"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "1f4966dcfade7d7e1b95ba766cf0b7509db31689"
Commit 1862b27f authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Revert "update qsksvs pipeline"

This reverts commit bfc997a7.
parent f42beae8
......@@ -5,7 +5,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
namespace ck_tile {
......@@ -100,7 +99,8 @@ struct BlockFmhaPipelineQSKSVS
static constexpr const char* name = "qs";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
// using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
using DropoutType = int32_t; // unused
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
......@@ -267,8 +267,7 @@ struct BlockFmhaPipelineQSKSVS
bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
// Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
......@@ -621,46 +620,10 @@ struct BlockFmhaPipelineQSKSVS
return o_acc;
}
// template <typename QDramBlockWindowTmp,
// typename KDramBlockWindowTmp,
// typename VDramBlockWindowTmp,
// typename BiasDramBlockWindowTmp,
// typename LSEDramBlockWindowTmp,
// typename PositionEncoding>
// CK_TILE_HOST_DEVICE auto
// operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
// const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
// const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
// const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
// LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
// FmhaMask mask,
// PositionEncoding position_encoding,
// float scale_s,
// void* smem_ptr) const
// {
// return operator()(q_dram_block_window_tmp,
// identity{},
// k_dram_block_window_tmp,
// identity{},
// v_dram_block_window_tmp,
// identity{},
// bias_dram_block_window_tmp,
// identity{},
// lse_dram_block_window_tmp,
// identity{},
// identity{},
// identity{},
// identity{},
// mask,
// position_encoding,
// scale_s,
// smem_ptr);
// }
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
......@@ -668,13 +631,11 @@ struct BlockFmhaPipelineQSKSVS
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
DropoutType& dropout) const
void* smem_ptr) const
{
return operator()(q_dram_block_window_tmp,
identity{},
......@@ -684,7 +645,6 @@ struct BlockFmhaPipelineQSKSVS
identity{},
bias_dram_block_window_tmp,
identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp,
identity{},
identity{},
......@@ -693,8 +653,7 @@ struct BlockFmhaPipelineQSKSVS
mask,
position_encoding,
scale_s,
smem_ptr,
dropout);
smem_ptr);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment