Unverified Commit 2c8e04aa authored by Max Podkorytov's avatar Max Podkorytov
Browse files

clang-format

parent 503a7da6
...@@ -849,9 +849,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -849,9 +849,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
else // fmha_fwd_traits or fmha_splitkv_traits else // fmha_fwd_traits or fmha_splitkv_traits
{ {
traits.is_group_mode = (mode == mode_enum::group); traits.is_group_mode = (mode == mode_enum::group);
traits.mask_type = mask.type; traits.mask_type = mask.type;
traits.bias_type = bias.type; traits.bias_type = bias.type;
// traits.has_lse = lse; // traits.has_lse = lse;
// traits.do_fp8_static_quant = squant; // traits.do_fp8_static_quant = squant;
...@@ -1504,13 +1504,13 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1504,13 +1504,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.type == mask_enum::mask_top_left)); mask.type == mask_enum::mask_top_left));
} }
auto pre_softmax = [] (auto s) { auto pre_softmax = [](auto s) {
//ck_tile::detail::swallow(s); // ck_tile::detail::swallow(s);
return CK_PRE_SOFTMAX_F; return CK_PRE_SOFTMAX_F;
}; };
s_host_ref.ForEach([&](auto& self, auto i) { s_host_ref.ForEach([&](auto& self, auto i) {
auto new_val = pre_softmax(self(i)); auto new_val = pre_softmax(self(i));
self(i) = new_val; self(i) = new_val;
}); });
if(lse) if(lse)
......
...@@ -20,7 +20,10 @@ ...@@ -20,7 +20,10 @@
namespace ck_tile { namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_, typename ScoreModFunction_, typename PreSoftmaxFunction_> template <typename FmhaPipeline_,
typename EpiloguePipeline_,
typename ScoreModFunction_,
typename PreSoftmaxFunction_>
struct FmhaFwdKernel struct FmhaFwdKernel
{ {
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>; using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
...@@ -1317,8 +1320,7 @@ struct FmhaFwdKernel ...@@ -1317,8 +1320,7 @@ struct FmhaFwdKernel
}; };
auto pre_softmax_def = PreSoftmaxFunction_{}; auto pre_softmax_def = PreSoftmaxFunction_{};
auto pre_softmax_arg = [pre_softmax_def]( auto pre_softmax_arg = [pre_softmax_def](typename PreSoftmaxFunction_::TScore s) {
typename PreSoftmaxFunction_::TScore s) {
return pre_softmax_def(s); return pre_softmax_def(s);
}; };
......
...@@ -13,7 +13,9 @@ ...@@ -13,7 +13,9 @@
namespace ck_tile { namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) // a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, typename PreSoftmaxFunction_> template <typename Problem_,
typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy,
typename PreSoftmaxFunction_>
struct BlockFmhaPipelineQRKSVSAsync struct BlockFmhaPipelineQRKSVSAsync
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment