Commit f9bb62d5 authored by fsx950223

Fix bugs in the grouped multi-head attention backward example: track per-group LSE tensors and pass them through the host reference forward pass, raise group_count from 1 to 3 so the grouped path is actually exercised, zero the gradient buffers before verification, feed the transposed P_drop into the Vgrad reference GEMM, and add the missing p_B1s size check to the device op's grouped argument validation.

parent 1a1a8924
@@ -399,6 +399,7 @@ int run(int argc, char* argv[])
    std::vector<Tensor<AccDataType>> s_g_m_ns;
    std::vector<Tensor<DataType>> p_g_m_ns;
    std::vector<Tensor<DataType>> y_g_m_os;
    std::vector<Tensor<LSEDataType>> lse_g_ms;
    std::vector<Tensor<DataType>> p_drop_g_m_ns;
    std::vector<Tensor<DataType>> q_tensors;
@@ -422,7 +423,7 @@ int run(int argc, char* argv[])
    std::vector<DeviceMemPtr> ygrad_tensors_device;
    std::vector<DeviceMemPtr> kgrad_tensors_device;
    std::vector<DeviceMemPtr> vgrad_tensors_device;
    std::size_t group_count = 1;
    std::size_t group_count = 3;
    std::size_t flop = 0, num_byte = 0;
    for(std::size_t i = 0; i < group_count; i++)
    {
@@ -624,6 +625,7 @@ int run(int argc, char* argv[])
        s_g_m_ns.push_back(s_g_m_n);
        p_g_m_ns.push_back(p_g_m_n);
        y_g_m_os.push_back(y_g_m_o);
        lse_g_ms.push_back(lse_g_m);
        q_tensors.push_back(q_gs_ms_ks);
        k_tensors.push_back(k_gs_ns_ks);
        v_tensors.push_back(v_gs_os_ns);
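The new lse_g_ms tensors hold the row-wise log-sum-exp of the attention scores. Backward attention kernels in the FlashAttention style keep this quantity so the softmax can be recomputed on the fly rather than stored; below is a minimal sketch of the per-row definition (an editorial illustration of the math, not code from this commit):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// lse(m) = max_n s(m,n) + log(sum_n exp(s(m,n) - max_n s(m,n)))
// With lse stored, p(m,n) can later be recomputed as exp(s(m,n) - lse(m)).
// Assumes a non-empty score row.
float row_logsumexp(const std::vector<float>& s_row)
{
    const float mx = *std::max_element(s_row.begin(), s_row.end());
    float sum      = 0.f;
    for(float v : s_row)
        sum += std::exp(v - mx);
    return mx + std::log(sum);
}
```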
@@ -718,8 +720,28 @@ int run(int argc, char* argv[])
    bool pass = true;
    if(do_verification)
    {
        for(int i = 0; i < group_count; i++)
        {
            int G1 = v_tensors[i].GetLengths()[1];
            z_tensors_device[i]->FromDevice(z_g_m_ns[i].data());
            run_attention_fwd_host(q_g_m_ks[i],
                                   k_g_n_ks[i],
                                   v_g_n_os[i],
                                   alpha,
                                   s_g_m_ns[i],
                                   p_g_m_ns[i],
                                   y_g_m_os[i],
                                   lse_g_ms[i],
                                   p_drop_g_m_ns[i],
                                   z_g_m_ns[i],
                                   p_dropout_in_16bits,
                                   rp_dropout);
            y_tensors[i].ForEach([&](auto& self, auto idx) {
                self(idx) = y_g_m_os[i](idx[0] * G1 + idx[1], idx[2], idx[3]);
            });
            y_tensors_device[i]->ToDevice(y_tensors[i].data());
            qgrad_tensors_device[i]->SetZero();
            kgrad_tensors_device[i]->SetZero();
            vgrad_tensors_device[i]->SetZero();
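The ForEach above bridges two layouts: the host reference tensors use a single flattened group dimension (G, M, O), while the example's y_tensors use (G0, G1, M, O). Assuming row-major flattening of the two group dimensions (an editorial reading of the index expression, not code from the commit), the mapping is:

```cpp
#include <cstddef>

// Hypothetical helper mirroring idx[0] * G1 + idx[1] in the ForEach above:
// collapses the (g0, g1) pair into the flat group index used by y_g_m_os.
inline std::size_t flatten_group(std::size_t g0, std::size_t g1, std::size_t G1)
{
    return g0 * G1 + g1;
}
```

This is also why G1 is read from v_tensors[i].GetLengths()[1] at the top of the loop: it is the extent of the second group dimension needed for the flattening.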
@@ -799,6 +821,7 @@ int run(int argc, char* argv[])
                }
                self(idx_gmn) = p_g_m_ns[i](idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
            });
            auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1});
            ref_gemm_grad_invoker.Run(RefGemmGradArg{
                p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
......
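The {0, 2, 1} transpose hands the reference GEMM the (G, N, M) view of P_drop. That follows from the forward product Y = P_drop · V, whose gradient with respect to V is Vgrad = P_dropᵀ · Ygrad. A naive single-group sketch of that reference computation (hypothetical helper, assuming row-major M×N and M×O inputs; not the library's API):

```cpp
#include <cstddef>
#include <vector>

// Vgrad[n][o] = sum_m P_drop[m][n] * Ygrad[m][o], i.e. P_drop^T * Ygrad.
void vgrad_reference(const std::vector<std::vector<float>>& p_drop, // M x N
                     const std::vector<std::vector<float>>& ygrad,  // M x O
                     std::vector<std::vector<float>>& vgrad)        // N x O
{
    const std::size_t M = p_drop.size();
    const std::size_t N = M ? p_drop[0].size() : 0;
    const std::size_t O = M ? ygrad[0].size() : 0;
    vgrad.assign(N, std::vector<float>(O, 0.f));
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t o = 0; o < O; ++o)
                vgrad[n][o] += p_drop[m][n] * ygrad[m][o];
}
```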
@@ -681,6 +681,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
        if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
             group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
             group_count_ == ck::type_convert<ck::index_t>(p_Zs.size()) &&
             group_count_ == ck::type_convert<ck::index_t>(p_B1s.size()) &&
             group_count_ == ck::type_convert<ck::index_t>(p_Cs.size()) &&
             group_count_ == ck::type_convert<ck::index_t>(p_Ygrads.size()) &&
......
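The device-op fix adds the p_B1s vector to the per-group size check, which had been left out of the hand-written conjunction. One way to keep such checks exhaustive is a variadic helper that validates every pointer vector in one call; a sketch of that alternative pattern (an editorial suggestion, not the library's API):

```cpp
#include <cstddef>

// Returns true only if every per-group pointer vector has exactly
// group_count entries.
template <typename... Vecs>
bool all_sizes_match(std::size_t group_count, const Vecs&... vecs)
{
    return ((vecs.size() == group_count) && ...); // C++17 fold expression
}
```

Called as, e.g., all_sizes_match(group_count_, p_As, p_Bs, p_Zs, p_B1s, p_Cs, p_Ygrads), it keeps the check to a single line per call site and avoids the copy-paste drift that caused this omission.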