Commit 4347bab1 authored by ltqin's avatar ltqin
Browse files

fix P computer

parent bb06d009
...@@ -1609,6 +1609,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1609,6 +1609,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock; const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
const float scale = 1.0f / std::sqrt(K);
// Initialize dQ // Initialize dQ
qgrad_thread_buf.Clear(); qgrad_thread_buf.Clear();
...@@ -1689,14 +1691,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1689,14 +1691,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
} }
else else
{ {
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]); s_slash_p_thread_buf(i) = scale * s_slash_p_thread_buf[i];
} }
}); });
} }
else else
{ {
static_for<0, s_slash_p_thread_buf.Size(), 1>{}( static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
[&](auto i) { s_element_op(acc_thread_buf(i), s_slash_p_thread_buf[i]); }); [&](auto i) { s_slash_p_thread_buf(i) = scale * s_slash_p_thread_buf[i]; });
} }
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment