Commit 0e308145 authored by danyao12's avatar danyao12
Browse files

fix grouped bwd example host issue

parent 665b08cf
...@@ -948,9 +948,12 @@ int run(int argc, char* argv[]) ...@@ -948,9 +948,12 @@ int run(int argc, char* argv[])
{ {
auto idx_gmo = idx_gmn; auto idx_gmo = idx_gmn;
idx_gmo[2] = o; idx_gmo[2] = o;
ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_os[i](idx_gmo); ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
ck::type_convert<AccDataType>(y_g_m_os[i](idx_gmo));
} }
self(idx_gmn) = p_g_m_ns[i](idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y); self(idx_gmn) = ck::type_convert<DataType>(
ck::type_convert<AccDataType>(p_g_m_ns[i](idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
}); });
auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1}); auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment