Commit ac3c1563 authored by ltqin's avatar ltqin
Browse files

change if to #if

parent 61f4a7ee
...@@ -523,13 +523,14 @@ int run(int argc, char* argv[]) ...@@ -523,13 +523,14 @@ int run(int argc, char* argv[])
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
if(PRINT_HOST) #if PRINT_HOST
{ {
std::cout << "q_g_m_k ref:\n" << q_g_m_k; std::cout << "q_g_m_k ref:\n" << q_g_m_k;
std::cout << "k_g_n_k ref:\n" << k_g_n_k; std::cout << "k_g_n_k ref:\n" << k_g_n_k;
std::cout << "v_g_n_o ref:\n" << v_g_n_o; std::cout << "v_g_n_o ref:\n" << v_g_n_o;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o; std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
} }
#endif
// Gradients // Gradients
auto ref_gemm_grad = ReferenceGemmGradInstance{}; auto ref_gemm_grad = ReferenceGemmGradInstance{};
...@@ -540,13 +541,14 @@ int run(int argc, char* argv[]) ...@@ -540,13 +541,14 @@ int run(int argc, char* argv[])
auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1}); auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm_grad_invoker.Run(RefGemmGradArg{
ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}}); ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
if(PRINT_HOST) #if PRINT_HOST
{ {
std::cout << "===== dP = dY * V^T\n"; std::cout << "===== dP = dY * V^T\n";
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o; std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "v_g_o_n ref:\n" << v_g_o_n; std::cout << "v_g_o_n ref:\n" << v_g_o_n;
std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n; std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
} }
#endif
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i) // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) { sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
...@@ -559,7 +561,7 @@ int run(int argc, char* argv[]) ...@@ -559,7 +561,7 @@ int run(int argc, char* argv[])
} }
self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y); self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
}); });
if(PRINT_HOST) #if PRINT_HOST
{ {
std::cout << "===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)\n"; std::cout << "===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)\n";
std::cout << "p_g_m_n ref:\n" << p_g_m_n; std::cout << "p_g_m_n ref:\n" << p_g_m_n;
...@@ -568,41 +570,45 @@ int run(int argc, char* argv[]) ...@@ -568,41 +570,45 @@ int run(int argc, char* argv[])
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o; std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n; std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
} }
#endif
// dV = P^T * dY // dV = P^T * dY
auto p_g_n_m = p_g_m_n.Transpose({0, 2, 1}); auto p_g_n_m = p_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm_grad_invoker.Run(RefGemmGradArg{
p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}}); p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
if(PRINT_HOST) #if PRINT_HOST
{ {
std::cout << "===== dV = P^T * dY\n"; std::cout << "===== dV = P^T * dY\n";
std::cout << "p_g_n_m ref:\n" << p_g_n_m; std::cout << "p_g_n_m ref:\n" << p_g_n_m;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o; std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o; std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
} }
#endif
// dQ = alpha * dS * K // dQ = alpha * dS * K
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm_grad_invoker.Run(RefGemmGradArg{
sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}}); sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
if(PRINT_HOST) #if PRINT_HOST
{ {
std::cout << "===== dQ = alpha * dS * K\n"; std::cout << "===== dQ = alpha * dS * K\n";
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n; std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
std::cout << "k_g_n_k ref:\n" << k_g_n_k; std::cout << "k_g_n_k ref:\n" << k_g_n_k;
std::cout << "qgrad_g_m_k ref:\n" << qgrad_g_m_k; std::cout << "qgrad_g_m_k ref:\n" << qgrad_g_m_k;
} }
#endif
// dK = alpha * dS^T * Q // dK = alpha * dS^T * Q
auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1}); auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm_grad_invoker.Run(RefGemmGradArg{
sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}}); sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
if(PRINT_HOST) #if PRINT_HOST
{ {
std::cout << "===== dK = alpha * dS^T * Q\n"; std::cout << "===== dK = alpha * dS^T * Q\n";
std::cout << "sgrad_g_n_m ref:\n" << sgrad_g_n_m; std::cout << "sgrad_g_n_m ref:\n" << sgrad_g_n_m;
std::cout << "q_g_m_k ref:\n" << q_g_m_k; std::cout << "q_g_m_k ref:\n" << q_g_m_k;
std::cout << "kgrad_g_n_k ref:\n" << kgrad_g_n_k; std::cout << "kgrad_g_n_k ref:\n" << kgrad_g_n_k;
} }
#endif
Tensor<DataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<DataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<DataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment