Commit 81157e04 authored by ltqin

host dK and dQ correct

parent bc9e2f25
@@ -210,7 +210,8 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ushort p_dropout)
+                            ushort p_dropout_in_16bits,
+                            float rp_dropout)
 {
     // S = alpha * Q * K^T
     auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
@@ -241,7 +242,8 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     // P_dropout
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
-    auto ref_dropout_argment = ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout);
+    auto ref_dropout_argment =
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // std::cout << "p_drop_g_m_n ref:\n" << p_drop_g_m_n;
@@ -264,8 +266,8 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M  = 128;
-    ck::index_t N  = 256;
+    ck::index_t M  = 512;
+    ck::index_t N  = 512;
     ck::index_t K  = 128;
     ck::index_t O  = 128;
     ck::index_t G0 = 1;
@@ -474,7 +476,8 @@ int run(int argc, char* argv[])
                           lse_g_m,
                           p_drop_g_m_n,
                           z_g_m_n,
-                          p_dropout_in_16bits);
+                          p_dropout_in_16bits,
+                          rp_dropout);
     y_gs_ms_os.ForEach(
         [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
@@ -589,7 +592,8 @@ int run(int argc, char* argv[])
                           lse_g_m,
                           p_drop_g_m_n,
                           z_g_m_n,
-                          p_dropout_in_16bits);
+                          p_dropout_in_16bits,
+                          rp_dropout);
     y_gs_ms_os.ForEach([&](auto& self, auto idx) {
         self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
     });
@@ -633,19 +637,19 @@ int run(int argc, char* argv[])
 #if PRINT_HOST
     {
         std::cout << "===== dP = dY * V^T\n";
-        std::cout << "ygrad_drop_g_m_o ref:\n" << pgrad_drop_g_m_n;
+        std::cout << "ygrad_drop_g_m_o ref:\n" << ygrad_drop_g_m_n;
         std::cout << "v_g_o_n ref:\n" << v_g_o_n;
-        std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
+        std::cout << "pgrad_drop_g_m_n ref:\n" << pgrad_drop_g_m_n;
     }
 #endif
     // dP = dP_dropout . Z
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
-    auto ref_dropout_argment =
-        ref_dropout.MakeArgument(z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout);
+    auto ref_dropout_argment = ref_dropout.MakeArgument(
+        z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
     sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
         float ygrad_dot_y = 0;
         for(int o = 0; o < O; o++)
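The ForEach body is cut off by the hunk boundary, but the comment above it states the full rule: dS_i_j = P_i_j * (dP_i_j - dY_i dot Y_i). A self-contained sketch of that row-wise computation on plain float vectors (names hypothetical):

#include <cstddef>
#include <vector>

// dS_i_j = P_i_j * (dP_i_j - dY_i dot Y_i), computed one row i at a time
void softmax_backward_row(const std::vector<float>& p,      // P_i_*   (length N)
                          const std::vector<float>& pgrad,  // dP_i_*  (length N)
                          const std::vector<float>& y,      // Y_i_*   (length O)
                          const std::vector<float>& ygrad,  // dY_i_*  (length O)
                          std::vector<float>& sgrad)        // dS_i_*  (length N)
{
    // the dot product is shared by every column j of row i
    float ygrad_dot_y = 0.f;
    for(std::size_t o = 0; o < y.size(); ++o)
        ygrad_dot_y += ygrad[o] * y[o];

    for(std::size_t j = 0; j < p.size(); ++j)
        sgrad[j] = p[j] * (pgrad[j] - ygrad_dot_y);
}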
@@ -666,31 +670,22 @@ int run(int argc, char* argv[])
         std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
     }
 #endif
-    // dV = rp_dropout * P_drop^T * dY
-    auto pdrop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
-    ref_gemm_grad_invoker.Run(RefGemmGradArg{pdrop_g_n_m,
-                                             ygrad_g_m_o,
-                                             vgrad_g_n_o,
-                                             PassThrough{},
-                                             PassThrough{},
-                                             Scale{rp_dropout}});
+    // dV = P_drop^T * dY
+    auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
+    ref_gemm_grad_invoker.Run(RefGemmGradArg{
+        p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
 #if PRINT_HOST
     {
         std::cout << "===== dV = P^T * dY\n";
-        std::cout << "pdrop_g_n_m ref:\n" << pdrop_g_n_m;
+        std::cout << "p_drop_g_n_m ref:\n" << p_drop_g_n_m;
         std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
         std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
     }
 #endif
     // dQ = alpha * dS * K
-    ref_gemm_grad_invoker.Run(RefGemmGradArg{sgrad_g_m_n,
-                                             k_g_n_k,
-                                             qgrad_g_m_k,
-                                             PassThrough{},
-                                             PassThrough{},
-                                             Scale{alpha * rp_dropout}});
+    ref_gemm_grad_invoker.Run(RefGemmGradArg{
+        sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
 #if PRINT_HOST
     {
         std::cout << "===== dQ = alpha * dS * K\n";
@@ -702,12 +697,8 @@ int run(int argc, char* argv[])
     // dK = alpha * dS^T * Q
     auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
-    ref_gemm_grad_invoker.Run(RefGemmGradArg{sgrad_g_n_m,
-                                             q_g_m_k,
-                                             kgrad_g_n_k,
-                                             PassThrough{},
-                                             PassThrough{},
-                                             Scale{alpha * rp_dropout}});
+    ref_gemm_grad_invoker.Run(RefGemmGradArg{
+        sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
 #if PRINT_HOST
     {
         std::cout << "===== dK = alpha * dS^T * Q\n";
...
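These scale changes are the substance of the fix named in the commit message. Once ReferenceDropout applies rp_dropout itself, the 1/(1-p) factor reaches only the dP term of dS_i_j = P_i_j * (dP_i_j - dY_i dot Y_i); the old Scale{alpha * rp_dropout} also multiplied the dY_i dot Y_i term, which already carries the factor through Y = P_drop * V, so host dQ and dK were off. For dV the two placements are algebraically identical. A minimal check, with hypothetical scalars standing in for one element of the dV GEMM:

#include <cassert>
#include <cmath>

int main()
{
    const bool  keep = true;           // dropout mask bit (z < p_dropout_in_16bits)
    const float p = 0.25f, dy = 0.5f;  // one element of P and of dY
    const float rp = 1.25f;            // rp_dropout = 1 / keep probability

    // old placement: mask only in the dropout op, rp_dropout in the GEMM scale
    const float dv_old = rp * ((keep ? p : 0.f) * dy); // Scale{rp_dropout}

    // new placement: rp_dropout folded into the dropout op, plain GEMM scale
    const float dv_new = 1.0f * ((keep ? p * rp : 0.f) * dy); // Scale{1.0f}

    assert(std::fabs(dv_old - dv_new) < 1e-6f);
    return 0;
}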
@@ -1900,7 +1900,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                       "");
         // TODO: tune gemm2 pipeline
-        // dV = P^T * dY
+        // dV = P_drop^T * dY
         v_slash_k_grad_thread_buf.Clear();
         static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
             // load VGrad Gemm B
...
@@ -25,14 +25,20 @@ struct ReferenceDropout : public device::BaseOperator
     Argument(const Tensor<RefDataType>& ref,
              const Tensor<InDataType>& in,
              Tensor<OutDataType>& out,
-             RefDataType p_dropout)
-        : ref_(ref), in_(in), out_(out), p_dropout_(p_dropout)
+             RefDataType p_dropout_in_16bits,
+             float rp_dropout)
+        : ref_(ref),
+          in_(in),
+          out_(out),
+          p_dropout_in_16bits_(p_dropout_in_16bits),
+          rp_dropout_(ck::type_convert<OutDataType>(rp_dropout))
     {
     }
     const Tensor<RefDataType>& ref_;
     const Tensor<InDataType>& in_;
     Tensor<OutDataType>& out_;
-    RefDataType p_dropout_;
+    RefDataType p_dropout_in_16bits_;
+    OutDataType rp_dropout_;
 };
 // Invoker
@@ -41,7 +47,8 @@ struct ReferenceDropout : public device::BaseOperator
     float Run(const Argument& arg)
     {
         arg.out_.ForEach([&](auto& self, auto idx) {
-            self(idx) = arg.ref_(idx) < arg.p_dropout_ ? arg.in_(idx) : 0;
+            self(idx) =
+                arg.ref_(idx) < arg.p_dropout_in_16bits_ ? arg.in_(idx) * arg.rp_dropout_ : 0;
         });
         return 0;
     }
@@ -64,9 +71,10 @@ struct ReferenceDropout : public device::BaseOperator
     static auto MakeArgument(const Tensor<RefDataType>& ref,
                              const Tensor<InDataType>& in,
                              Tensor<OutDataType>& out,
-                             RefDataType p_dropout)
+                             RefDataType p_dropout_in_16bits,
+                             float rp_dropout)
     {
-        return Argument{ref, in, out, p_dropout};
+        return Argument{ref, in, out, p_dropout_in_16bits, rp_dropout};
     }
     static auto MakeInvoker() { return Invoker{}; }
...
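Taken together, the updated operator implements standard inverted dropout on precomputed 16-bit random values: keep an element when its random value falls below the threshold, and scale kept elements by the reciprocal keep probability so the expected value is unchanged. A standalone sketch of the same semantics on plain vectors, mirroring Run() above:

#include <cstddef>
#include <cstdint>
#include <vector>

// Inverted dropout against precomputed uniform 16-bit random values.
void reference_dropout(const std::vector<uint16_t>& z,   // random value per element
                       const std::vector<float>& in,
                       std::vector<float>& out,
                       uint16_t p_dropout_in_16bits,     // keep threshold
                       float rp_dropout)                 // 1 / keep probability
{
    for(std::size_t i = 0; i < in.size(); ++i)
        out[i] = z[i] < p_dropout_in_16bits ? in[i] * rp_dropout : 0.f;
}

Scaling at mask time, rather than in the downstream GEMMs, keeps the expectation of every masked tensor equal to its unmasked value, which is what lets the gradient GEMMs above use plain Scale{1.0f} and Scale{alpha}.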