Commit 90e99a95 authored by Qianfeng Zhang
Browse files

Let BlockDropout reuse LDS with V

parent 281110cf
@@ -533,9 +533,8 @@ struct BlockFmhaPipelineQRKSVSAsync
     if constexpr(kHasDropout)
     {
-        auto randval_ptr = reinterpret_cast<char*>(smem_ptr) +
-                           Policy::template GetSmemSizeK<Problem>() +
-                           Policy::template GetSmemSizeV<Problem>();
+        auto randval_ptr =
+            reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>();
         dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
             smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
     }
......
@@ -177,8 +177,10 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
     {
         // assume Q can reuse the shared memory with K or V
-        return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>() + GetSmemSizeV<Problem>()) +
-               GetSmemSizeDropout<Problem>(0);
+        // assume Dropout can reuse the shared memory with V
+        return max(GetSmemSizeQ<Problem>(),
+                   GetSmemSizeK<Problem>() +
+                       max(GetSmemSizeV<Problem>(), GetSmemSizeDropout<Problem>(0)));
     }
 };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment