Commit 4347bab1 authored by ltqin's avatar ltqin
Browse files

fix P computer

parent bb06d009
......@@ -1609,6 +1609,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
const float scale = 1.0f / std::sqrt(K);
// Initialize dQ
qgrad_thread_buf.Clear();
......@@ -1689,14 +1691,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
else
{
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
s_slash_p_thread_buf(i) = scale * s_slash_p_thread_buf[i];
}
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
[&](auto i) { s_element_op(acc_thread_buf(i), s_slash_p_thread_buf[i]); });
[&](auto i) { s_slash_p_thread_buf(i) = scale * s_slash_p_thread_buf[i]; });
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment