Commit edb78a47 authored by Max Podkorytov
Browse files

clang-format and remove dead code

parent 60113859
...@@ -128,42 +128,39 @@ struct BlockFmhaPipelineQSKSVS ...@@ -128,42 +128,39 @@ struct BlockFmhaPipelineQSKSVS
typename OAccElementFunction, typename OAccElementFunction,
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
// operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp & q_dram_block_window_tmp, // M0*K0 tile
// const QElementFunction& q_element_func, const QElementFunction &
// const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile q_element_func,
// const KElementFunction& k_element_func, const KDramBlockWindowTmp &
// const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile k_dram_block_window_tmp, // N0*K0 tile
// const VElementFunction& v_element_func, const KElementFunction &
// const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile k_element_func,
// const BiasElementFunction& bias_element_func, const VDramBlockWindowTmp &
// LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile v_dram_block_window_tmp, // N1*K1 tile
// const LSEElementFunction& lse_element_func, const VElementFunction &
// const SAccElementFunction& s_acc_element_func, v_element_func,
// const PComputeElementFunction& p_compute_element_func, const BiasDramBlockWindowTmp &
// const OAccElementFunction& o_acc_element_func, bias_dram_block_window_tmp, // M0*N0 tile
// FmhaMask mask, const BiasElementFunction &
// PositionEncoding position_encoding, bias_element_func,
// float scale_s, RandValDramBlockWindowTmp &
// void* smem_ptr) const randval_dram_block_window_tmp,
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile LSEDramBlockWindowTmp &
const QElementFunction& q_element_func, lse_dram_window_tmp, // M0*1 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const LSEElementFunction &
const KElementFunction& k_element_func, lse_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const SAccElementFunction &
const VElementFunction& v_element_func, s_acc_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const PComputeElementFunction &
const BiasElementFunction& bias_element_func, p_compute_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, const OAccElementFunction &
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile o_acc_element_func,
const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr,
DropoutType& dropout) const DropoutType &
dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -263,12 +260,12 @@ struct BlockFmhaPipelineQSKSVS ...@@ -263,12 +260,12 @@ struct BlockFmhaPipelineQSKSVS
{seqlen_k_start, 0}); {seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window( auto bias_dram_window =
bias_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(), bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
// Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>()); // Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto v_dram_window = auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
...@@ -621,41 +618,6 @@ struct BlockFmhaPipelineQSKSVS ...@@ -621,41 +618,6 @@ struct BlockFmhaPipelineQSKSVS
return o_acc; return o_acc;
} }
// template <typename QDramBlockWindowTmp,
// typename KDramBlockWindowTmp,
// typename VDramBlockWindowTmp,
// typename BiasDramBlockWindowTmp,
// typename LSEDramBlockWindowTmp,
// typename PositionEncoding>
// CK_TILE_HOST_DEVICE auto
// operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
// const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
// const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
// const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
// LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
// FmhaMask mask,
// PositionEncoding position_encoding,
// float scale_s,
// void* smem_ptr) const
// {
// return operator()(q_dram_block_window_tmp,
// identity{},
// k_dram_block_window_tmp,
// identity{},
// v_dram_block_window_tmp,
// identity{},
// bias_dram_block_window_tmp,
// identity{},
// lse_dram_block_window_tmp,
// identity{},
// identity{},
// identity{},
// identity{},
// mask,
// position_encoding,
// scale_s,
// smem_ptr);
// }
template <typename QDramBlockWindowTmp, template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
......
...@@ -471,7 +471,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -471,7 +471,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem, index_t IBuf = 0> template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{}) MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
{ {
// K is always k-major, we use async-copy to load into LDS // K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
...@@ -526,7 +526,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -526,7 +526,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM #if K_LDS_LOAD_USE_OFFSET_TRANSFORM
template <typename Problem, index_t IBuf = 0> template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{}) MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{})
{ {
// K is always k-major, we use async-copy to load into LDS // K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment