Commit 6d220ec8 authored by fsx950223's avatar fsx950223
Browse files

fix a bug

parent 8e3c6991
......@@ -136,6 +136,9 @@ __global__ void
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = p_dropout;
ignore = seed;
ignore = offset;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -668,12 +671,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op}
c_element_op_{c_element_op},
p_dropout_{p_drop}
{
p_dropout_ = 1.f - p_drop;
float rp_dropout_ = 1.f / p_dropout_;
acc_element_op_.Append(rp_dropout_);
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment