Commit f9bb62d5 authored by fsx950223's avatar fsx950223
Browse files

fix bugs

parent 1a1a8924
...@@ -399,6 +399,7 @@ int run(int argc, char* argv[]) ...@@ -399,6 +399,7 @@ int run(int argc, char* argv[])
std::vector<Tensor<AccDataType>> s_g_m_ns; std::vector<Tensor<AccDataType>> s_g_m_ns;
std::vector<Tensor<DataType>> p_g_m_ns; std::vector<Tensor<DataType>> p_g_m_ns;
std::vector<Tensor<DataType>> y_g_m_os; std::vector<Tensor<DataType>> y_g_m_os;
std::vector<Tensor<LSEDataType>> lse_g_ms;
std::vector<Tensor<DataType>> p_drop_g_m_ns; std::vector<Tensor<DataType>> p_drop_g_m_ns;
std::vector<Tensor<DataType>> q_tensors; std::vector<Tensor<DataType>> q_tensors;
...@@ -422,7 +423,7 @@ int run(int argc, char* argv[]) ...@@ -422,7 +423,7 @@ int run(int argc, char* argv[])
std::vector<DeviceMemPtr> ygrad_tensors_device; std::vector<DeviceMemPtr> ygrad_tensors_device;
std::vector<DeviceMemPtr> kgrad_tensors_device; std::vector<DeviceMemPtr> kgrad_tensors_device;
std::vector<DeviceMemPtr> vgrad_tensors_device; std::vector<DeviceMemPtr> vgrad_tensors_device;
std::size_t group_count = 1; std::size_t group_count = 3;
std::size_t flop = 0, num_byte = 0; std::size_t flop = 0, num_byte = 0;
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
...@@ -624,6 +625,7 @@ int run(int argc, char* argv[]) ...@@ -624,6 +625,7 @@ int run(int argc, char* argv[])
s_g_m_ns.push_back(s_g_m_n); s_g_m_ns.push_back(s_g_m_n);
p_g_m_ns.push_back(p_g_m_n); p_g_m_ns.push_back(p_g_m_n);
y_g_m_os.push_back(y_g_m_o); y_g_m_os.push_back(y_g_m_o);
lse_g_ms.push_back(lse_g_m);
q_tensors.push_back(q_gs_ms_ks); q_tensors.push_back(q_gs_ms_ks);
k_tensors.push_back(k_gs_ns_ks); k_tensors.push_back(k_gs_ns_ks);
v_tensors.push_back(v_gs_os_ns); v_tensors.push_back(v_gs_os_ns);
...@@ -718,8 +720,28 @@ int run(int argc, char* argv[]) ...@@ -718,8 +720,28 @@ int run(int argc, char* argv[])
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
{ {
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
int G1 = v_tensors[i].GetLengths()[1];
z_tensors_device[i]->FromDevice(z_g_m_ns[i].data());
run_attention_fwd_host(q_g_m_ks[i],
k_g_n_ks[i],
v_g_n_os[i],
alpha,
s_g_m_ns[i],
p_g_m_ns[i],
y_g_m_os[i],
lse_g_ms[i],
p_drop_g_m_ns[i],
z_g_m_ns[i],
p_dropout_in_16bits,
rp_dropout);
y_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_os[i](idx[0] * G1 + idx[1], idx[2], idx[3]);
});
y_tensors_device[i]->ToDevice(y_tensors[i].data());
qgrad_tensors_device[i]->SetZero(); qgrad_tensors_device[i]->SetZero();
kgrad_tensors_device[i]->SetZero(); kgrad_tensors_device[i]->SetZero();
vgrad_tensors_device[i]->SetZero(); vgrad_tensors_device[i]->SetZero();
...@@ -799,6 +821,7 @@ int run(int argc, char* argv[]) ...@@ -799,6 +821,7 @@ int run(int argc, char* argv[])
} }
self(idx_gmn) = p_g_m_ns[i](idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y); self(idx_gmn) = p_g_m_ns[i](idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
}); });
auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1}); auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm_grad_invoker.Run(RefGemmGradArg{
p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}}); p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
......
...@@ -681,6 +681,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -681,6 +681,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) && if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Zs.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_B1s.size()) && group_count_ == ck::type_convert<ck::index_t>(p_B1s.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Cs.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Cs.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Ygrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Ygrads.size()) &&
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment