Commit 81157e04 authored by ltqin

Correct host reference dK and dQ

parent bc9e2f25
......@@ -210,7 +210,8 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ushort p_dropout)
ushort p_dropout_in_16bits,
float rp_dropout)
{
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
......@@ -241,7 +242,8 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
// P_dropout
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout);
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// std::cout << "p_drop_g_m_n ref:\n" << p_drop_g_m_n;
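For reference, the dropout call above applies inverted dropout to P = softmax(alpha * Q * K^T). A sketch of the rule, under the assumption that Z holds uniform 16-bit random values, the threshold encodes the keep probability p_keep = 1 - p_drop as floor(p_keep * 65535), and rp_dropout = 1 / p_keep:

P^{\mathrm{drop}}_{ij} = \begin{cases} P_{ij} \cdot r_{p}, & Z_{ij} < \text{threshold} \\ 0, & \text{otherwise} \end{cases}, \qquad r_{p} = \frac{1}{p_{\text{keep}}}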
......@@ -264,8 +266,8 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 128;
ck::index_t N = 256;
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = 128;
ck::index_t O = 128;
ck::index_t G0 = 1;
......@@ -474,7 +476,8 @@ int run(int argc, char* argv[])
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits);
p_dropout_in_16bits,
rp_dropout);
y_gs_ms_os.ForEach(
[&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
......@@ -589,7 +592,8 @@ int run(int argc, char* argv[])
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits);
p_dropout_in_16bits,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
});
......@@ -633,19 +637,19 @@ int run(int argc, char* argv[])
#if PRINT_HOST
{
std::cout << "===== dP = dY * V^T\n";
std::cout << "ygrad_drop_g_m_o ref:\n" << pgrad_drop_g_m_n;
std::cout << "ygrad_drop_g_m_o ref:\n" << ygrad_drop_g_m_n;
std::cout << "v_g_o_n ref:\n" << v_g_o_n;
std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
std::cout << "pgrad_drop_g_m_n ref:\n" << pgrad_drop_g_m_n;
}
#endif
// dP = dP_dropout . Z
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout);
auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
float ygrad_dot_y = 0;
for(int o = 0; o < O; o++)
......@@ -666,31 +670,22 @@ int run(int argc, char* argv[])
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
}
#endif
// dV = rp_dropout * P_drop^T * dY
auto pdrop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{pdrop_g_n_m,
ygrad_g_m_o,
vgrad_g_n_o,
PassThrough{},
PassThrough{},
Scale{rp_dropout}});
// dV = P_drop^T * dY
auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
#if PRINT_HOST
{
std::cout << "===== dV = P^T * dY\n";
std::cout << "pdrop_g_n_m ref:\n" << pdrop_g_n_m;
std::cout << "p_drop_g_n_m ref:\n" << p_drop_g_n_m;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
}
#endif
// dQ = alpha * dS * K
ref_gemm_grad_invoker.Run(RefGemmGradArg{sgrad_g_m_n,
k_g_n_k,
qgrad_g_m_k,
PassThrough{},
PassThrough{},
Scale{alpha * rp_dropout}});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST
{
std::cout << "===== dQ = alpha * dS * K\n";
......@@ -702,12 +697,8 @@ int run(int argc, char* argv[])
// dK = alpha * dS^T * Q
auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{sgrad_g_n_m,
q_g_m_k,
kgrad_g_n_k,
PassThrough{},
PassThrough{},
Scale{alpha * rp_dropout}});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST
{
std::cout << "===== dK = alpha * dS^T * Q\n";
......
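Taken together, the host reference above follows the attention-backward identities stated in its comments; a compact restatement, with the dropout mask and rescaling already folded into P_drop and dP (so the gemms carry only alpha):

dP^{\mathrm{drop}} = dY\, V^{T}
dP_{ij} = \begin{cases} dP^{\mathrm{drop}}_{ij} \cdot r_{p}, & Z_{ij} < \text{threshold} \\ 0, & \text{otherwise} \end{cases}
dS_{ij} = P_{ij}\left(dP_{ij} - \sum_{o} dY_{io}\, Y_{io}\right)
dV = (P^{\mathrm{drop}})^{T}\, dY, \qquad dQ = \alpha\, dS\, K, \qquad dK = \alpha\, dS^{T}\, Q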
......@@ -1900,7 +1900,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
"");
// TODO: tune gemm2 pipeline
// dV = P^T * dY
// dV = P_drop^T * dY
v_slash_k_grad_thread_buf.Clear();
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
// load VGrad Gemm B
......
......@@ -25,14 +25,20 @@ struct ReferenceDropout : public device::BaseOperator
Argument(const Tensor<RefDataType>& ref,
const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
RefDataType p_dropout)
: ref_(ref), in_(in), out_(out), p_dropout_(p_dropout)
RefDataType p_dropout_in_16bits,
float rp_dropout)
: ref_(ref),
in_(in),
out_(out),
p_dropout_in_16bits_(p_dropout_in_16bits),
rp_dropout_(ck::type_convert<OutDataType>(rp_dropout))
{
}
const Tensor<RefDataType>& ref_;
const Tensor<InDataType>& in_;
Tensor<OutDataType>& out_;
RefDataType p_dropout_;
RefDataType p_dropout_in_16bits_;
OutDataType rp_dropout_;
};
// Invoker
......@@ -41,7 +47,8 @@ struct ReferenceDropout : public device::BaseOperator
float Run(const Argument& arg)
{
arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) = arg.ref_(idx) < arg.p_dropout_ ? arg.in_(idx) : 0;
self(idx) =
arg.ref_(idx) < arg.p_dropout_in_16bits_ ? arg.in_(idx) * arg.rp_dropout_ : 0;
});
return 0;
}
......@@ -64,9 +71,10 @@ struct ReferenceDropout : public device::BaseOperator
static auto MakeArgument(const Tensor<RefDataType>& ref,
const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
RefDataType p_dropout)
RefDataType p_dropout_in_16bits,
float rp_dropout)
{
return Argument{ref, in, out, p_dropout};
return Argument{ref, in, out, p_dropout_in_16bits, rp_dropout};
}
static auto MakeInvoker() { return Invoker{}; }
......
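A minimal standalone sketch of the inverted-dropout semantics that ReferenceDropout::Run implements above; the scalar names (p_drop, p_keep, threshold) and the use of std::mt19937 are illustrative assumptions, not taken from the library:

#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

int main()
{
    const float p_drop     = 0.2f;          // dropout rate (illustrative)
    const float p_keep     = 1.0f - p_drop; // keep probability
    const auto threshold   = static_cast<std::uint16_t>(p_keep * 65535.0f); // 16-bit keep threshold
    const float rp_dropout = 1.0f / p_keep; // rescale factor for kept elements

    std::mt19937 gen(42);
    std::uniform_int_distribution<unsigned> dist(0, 65535);

    std::vector<float> in(8, 1.0f); // stand-in for softmax outputs P
    std::vector<float> out(in.size());

    // Mirrors the reference loop: keep and rescale when z < threshold, else zero.
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        const auto z = static_cast<std::uint16_t>(dist(gen));
        out[i]       = (z < threshold) ? in[i] * rp_dropout : 0.0f;
    }

    for(float v : out)
        std::cout << v << ' ';
    std::cout << '\n';
    return 0;
}

Scaling the surviving elements by 1 / p_keep keeps the expectation of the masked tensor equal to the original (inverted dropout), which is why the gemm scales in the host reference only carry alpha.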