Commit 260ace4b authored by danyao12
Browse files

code cleanup

parent 9b4b4622
...@@ -58,43 +58,33 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_> ...@@ -58,43 +58,33 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start) index_t seqlen_qk_start)
{ {
if constexpr(IsDropout) constexpr auto config =
{ BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
constexpr auto config = using WG = remove_cvref_t<decltype(config.template at<0>())>;
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>(); constexpr index_t MWarp = config.template at<1>();
using WG = remove_cvref_t<decltype(config.template at<0>())>; constexpr index_t NWarp = config.template at<2>();
constexpr index_t MWarp = config.template at<1>(); constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t NWarp = config.template at<2>(); constexpr index_t kNPerStep = NWarp * WG::kN;
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
auto randval_dram_window = [&]() {
if constexpr(IsFwd)
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
}
else
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
}
}();
return randval_dram_window; const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
} auto randval_dram_window = [&]() {
else if constexpr(IsFwd)
{ {
(void)randval_dram_block_window_tmp; return make_tile_window(
(void)seqlen_qk_start; randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
}
else
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
}
}();
return make_null_tile_window(make_tuple(number<0>{}, number<0>{})); return randval_dram_window;
}
} }
template <typename BlockGemm> template <typename BlockGemm>
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment