Commit 0db65bfb authored by guangzlu's avatar guangzlu
Browse files

modified judgement of dropout

parent 7e402f6a
...@@ -44,7 +44,7 @@ struct BlockwiseDropout ...@@ -44,7 +44,7 @@ struct BlockwiseDropout
static_for<0, KRepeat, 1>{}([&](auto iK) { static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{}; auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset)); execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
tmp_index = tmp_index + 1; tmp_index = tmp_index + 1;
}); });
}); });
...@@ -79,7 +79,7 @@ struct BlockwiseDropout ...@@ -79,7 +79,7 @@ struct BlockwiseDropout
static_for<0, KRepeat, 1>{}([&](auto iK) { static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{}; auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset)); execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp[tmp_index]; z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1; tmp_index = tmp_index + 1;
}); });
...@@ -117,7 +117,7 @@ struct BlockwiseDropout ...@@ -117,7 +117,7 @@ struct BlockwiseDropout
constexpr auto iOffset = Number<tmp_size>{} * Offset{}; constexpr auto iOffset = Number<tmp_size>{} * Offset{};
static_for<0, tmp_size, 1>{}([&](auto i) { static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf(i + iOffset) = in_thread_buf(i + iOffset) =
execute_dropout(tmp[i.value] <= p_dropout_16bits, in_thread_buf(i + iOffset)); execute_dropout(tmp[i.value] < p_dropout_16bits, in_thread_buf(i + iOffset));
z_thread_buf(i) = tmp[i.value]; z_thread_buf(i) = tmp[i.value];
}); });
} }
......
...@@ -46,13 +46,22 @@ struct ReferenceDropout : public device::BaseOperator ...@@ -46,13 +46,22 @@ struct ReferenceDropout : public device::BaseOperator
{ {
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
arg.out_.ForEach([&](auto& self, auto idx) { if(arg.p_dropout_in_16bits_ < 65535)
self(idx) = {
arg.ref_(idx) <= arg.p_dropout_in_16bits_ arg.out_.ForEach([&](auto& self, auto idx) {
? ck::type_convert<OutDataType>(ck::type_convert<float>(arg.in_(idx)) * self(idx) = arg.ref_(idx) < arg.p_dropout_in_16bits_
ck::type_convert<float>(arg.rp_dropout_)) ? ck::type_convert<OutDataType>(
: 0; ck::type_convert<float>(arg.in_(idx)) *
}); ck::type_convert<float>(arg.rp_dropout_))
: 0;
});
}
else
{
arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) = ck::type_convert<OutDataType>(arg.in_(idx));
});
}
return 0; return 0;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment