Unverified commit bd1ae40f, authored by rocking5566 and committed by GitHub
Browse files

Merge branch 'develop' into gemm_layernorm_welford

parents 78ff5f81 d1567094
...@@ -49,7 +49,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -49,7 +49,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
int BatchStrideB0 = -1, int BatchStrideB0 = -1,
int BatchStrideB1 = -1, int BatchStrideB1 = -1,
int BatchStrideC = -1, int BatchStrideC = -1,
float alpha = 1.f) float alpha = -1.f)
{ {
...@@ -187,6 +187,10 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -187,6 +187,10 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
if(alpha < 0)
{
alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim)
}
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{}; auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha}; auto acc0_element_op = Acc0ElementOp{alpha};
......
...@@ -45,7 +45,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, ...@@ -45,7 +45,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
int O, int O,
int G0, int G0,
int G1, int G1,
float alpha = 1.f) float alpha = -1.f)
{ {
...@@ -154,6 +154,10 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, ...@@ -154,6 +154,10 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
if(alpha < 0)
{
alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim)
}
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{}; auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha}; auto acc0_element_op = Acc0ElementOp{alpha};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment