Commit 90e99a95 authored by Qianfeng Zhang
Browse files

Let BlockDropout reuse LDS with V

parent 281110cf
......@@ -533,9 +533,8 @@ struct BlockFmhaPipelineQRKSVSAsync
 if constexpr(kHasDropout)
 {
-    auto randval_ptr = reinterpret_cast<char*>(smem_ptr) +
-                       Policy::template GetSmemSizeK<Problem>() +
-                       Policy::template GetSmemSizeV<Problem>();
+    auto randval_ptr =
+        reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>();
     dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
         smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
 }
......
......@@ -177,8 +177,10 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
 CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
 {
     // assume Q can reuse the shared memory with K or V
-    return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>() + GetSmemSizeV<Problem>()) +
-           GetSmemSizeDropout<Problem>(0);
+    // assume Dropout can reuse the shared memory with V
+    return max(GetSmemSizeQ<Problem>(),
+               GetSmemSizeK<Problem>() +
+                   max(GetSmemSizeV<Problem>(), GetSmemSizeDropout<Problem>(0)));
 }
 };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment