Commit 2464edd0 authored by letaoqin's avatar letaoqin
Browse files

add comments for grouped host code

parent 28459058
...@@ -767,12 +767,13 @@ int run(int argc, char* argv[]) ...@@ -767,12 +767,13 @@ int run(int argc, char* argv[])
auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1}); auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1});
ref_gemm0_grad_invoker.Run(RefGemm0GradArg{ ref_gemm0_grad_invoker.Run(RefGemm0GradArg{
ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}}); ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
// dP = dP_dropout x Z
auto ref_dropout = ReferenceDropoutInstance{}; auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker(); auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument( auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout); z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment); ref_dropout_invoker.Run(ref_dropout_argment);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) { sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
float ygrad_dot_y = 0; float ygrad_dot_y = 0;
for(int o = 0; o < O; o++) for(int o = 0; o < O; o++)
...@@ -786,12 +787,14 @@ int run(int argc, char* argv[]) ...@@ -786,12 +787,14 @@ int run(int argc, char* argv[])
ck::type_convert<AccDataType>(p_g_m_ns[i](idx_gmn)) * ck::type_convert<AccDataType>(p_g_m_ns[i](idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y)); (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
}); });
// dV = P_drop^T * dY
auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1}); auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1});
ref_gemm1_grad_invoker.Run(RefGemm1GradArg{ ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}}); p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
// dQ = alpha * dS * K
ref_gemm1_grad_invoker.Run(RefGemm1GradArg{ ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}}); sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
// dK = alpha * dS^T * Q
auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1}); auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
ref_gemm1_grad_invoker.Run(RefGemm1GradArg{ ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}}); sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment