Commit 54dfedcd authored by guangzlu's avatar guangzlu
Browse files

Added a compile-time switch (`IsLseStoring`) to make storing the LSE (max + ln(sum)) optional in the attention forward kernel

parent d2eed8e6
......@@ -416,6 +416,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
template <bool HasMainKBlockLoop,
bool IsDropout,
bool IsLseStoring,
typename Block2CTileMap,
typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
......@@ -1019,7 +1020,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
else
{
......@@ -1149,22 +1149,26 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// Calculate max + ln(sum) and write out
static_for<0, MXdlPerWave, 1>{}(
[&](auto I) { lse_thread_buf(I) = running_max(I) + math::log(running_sum(I)); });
if(get_warp_local_1d_id() < AccM2)
if constexpr(IsLseStoring)
{
static_for<0, MXdlPerWave, 1>{}([&](auto I) {
// copy from VGPR to Global
lse_thread_copy_vgpr_to_global.Run(lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
make_tuple(I0, Number<I>{}, I0, I0),
lse_thread_buf,
lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
lse_grid_buf);
lse_thread_copy_vgpr_to_global.MoveDstSliceWindow(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl, make_multi_index(0, 1, 0, 0));
});
static_for<0, MXdlPerWave, 1>{}(
[&](auto I) { lse_thread_buf(I) = running_max(I) + math::log(running_sum(I)); });
if(get_warp_local_1d_id() < AccM2)
{
static_for<0, MXdlPerWave, 1>{}([&](auto I) {
// copy from VGPR to Global
lse_thread_copy_vgpr_to_global.Run(lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
make_tuple(I0, Number<I>{}, I0, I0),
lse_thread_buf,
lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
lse_grid_buf);
lse_thread_copy_vgpr_to_global.MoveDstSliceWindow(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl, make_multi_index(0, 1, 0, 0));
});
}
}
// shuffle C and write out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment