Commit 54dfedcd authored by guangzlu's avatar guangzlu
Browse files

added switch for lse storing in attn fwd

parent d2eed8e6
...@@ -416,6 +416,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -416,6 +416,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
bool IsDropout, bool IsDropout,
bool IsLseStoring,
typename Block2CTileMap, typename Block2CTileMap,
typename C0MatrixMask> typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
...@@ -1019,7 +1020,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1019,7 +1020,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} }
else else
{ {
...@@ -1149,22 +1149,26 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1149,22 +1149,26 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// Calculate max + ln(sum) and write out // Calculate max + ln(sum) and write out
static_for<0, MXdlPerWave, 1>{}(
[&](auto I) { lse_thread_buf(I) = running_max(I) + math::log(running_sum(I)); });
if(get_warp_local_1d_id() < AccM2) if constexpr(IsLseStoring)
{ {
static_for<0, MXdlPerWave, 1>{}([&](auto I) { static_for<0, MXdlPerWave, 1>{}(
// copy from VGPR to Global [&](auto I) { lse_thread_buf(I) = running_max(I) + math::log(running_sum(I)); });
lse_thread_copy_vgpr_to_global.Run(lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
make_tuple(I0, Number<I>{}, I0, I0), if(get_warp_local_1d_id() < AccM2)
lse_thread_buf, {
lse_grid_desc_mblock_mrepeat_mwave_mperxdl, static_for<0, MXdlPerWave, 1>{}([&](auto I) {
lse_grid_buf); // copy from VGPR to Global
lse_thread_copy_vgpr_to_global.Run(lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
lse_thread_copy_vgpr_to_global.MoveDstSliceWindow( make_tuple(I0, Number<I>{}, I0, I0),
lse_grid_desc_mblock_mrepeat_mwave_mperxdl, make_multi_index(0, 1, 0, 0)); lse_thread_buf,
}); lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
lse_grid_buf);
lse_thread_copy_vgpr_to_global.MoveDstSliceWindow(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl, make_multi_index(0, 1, 0, 0));
});
}
} }
// shuffle C and write out // shuffle C and write out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment