Commit 4003fefb authored by Anthony Chang's avatar Anthony Chang
Browse files

fix bug where scaling may not be applied in some code path

parent 8784a72e
......@@ -796,6 +796,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
});
}
else
{
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
......@@ -45,7 +45,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
int O,
int G0,
int G1,
float alpha = 1.f)
float alpha = -1.f)
{
......@@ -154,6 +154,10 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
if(alpha < 0)
{
alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim)
}
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment