Commit 0db65bfb authored by guangzlu's avatar guangzlu
Browse files

modified judgement of dropout

parent 7e402f6a
......@@ -44,7 +44,7 @@ struct BlockwiseDropout
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
tmp_index = tmp_index + 1;
});
});
......@@ -79,7 +79,7 @@ struct BlockwiseDropout
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
......@@ -117,7 +117,7 @@ struct BlockwiseDropout
constexpr auto iOffset = Number<tmp_size>{} * Offset{};
static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf(i + iOffset) =
execute_dropout(tmp[i.value] <= p_dropout_16bits, in_thread_buf(i + iOffset));
execute_dropout(tmp[i.value] < p_dropout_16bits, in_thread_buf(i + iOffset));
z_thread_buf(i) = tmp[i.value];
});
}
......
......@@ -46,13 +46,22 @@ struct ReferenceDropout : public device::BaseOperator
{
float Run(const Argument& arg)
{
arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) =
arg.ref_(idx) <= arg.p_dropout_in_16bits_
? ck::type_convert<OutDataType>(ck::type_convert<float>(arg.in_(idx)) *
ck::type_convert<float>(arg.rp_dropout_))
: 0;
});
if(arg.p_dropout_in_16bits_ < 65535)
{
arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) = arg.ref_(idx) < arg.p_dropout_in_16bits_
? ck::type_convert<OutDataType>(
ck::type_convert<float>(arg.in_(idx)) *
ck::type_convert<float>(arg.rp_dropout_))
: 0;
});
}
else
{
arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) = ck::type_convert<OutDataType>(arg.in_(idx));
});
}
return 0;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment