Commit 6d220ec8 authored by fsx950223's avatar fsx950223
Browse files

fix a bug

parent 8e3c6991
...@@ -136,6 +136,9 @@ __global__ void ...@@ -136,6 +136,9 @@ __global__ void
ignore = acc_element_op; ignore = acc_element_op;
ignore = b1_element_op; ignore = b1_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = p_dropout;
ignore = seed;
ignore = offset;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -668,12 +671,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -668,12 +671,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
b_element_op_{b_element_op}, b_element_op_{b_element_op},
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op},
p_dropout_{p_drop}
{ {
p_dropout_ = 1.f - p_drop;
float rp_dropout_ = 1.f / p_dropout_;
acc_element_op_.Append(rp_dropout_);
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment