Commit 4388b767 authored by danyao12
Browse files

optimize fwd v2 dropout to eliminate scratch

parent 00cb7e41
......@@ -198,55 +198,41 @@ struct BlockwiseDropout
});
}
// NOTE(review): this span is a rendered diff hunk whose +/- markers were lost, so
// two variants of ApplyDropoutWithZ are interleaved below and the text does not
// compile as written. The grouping in the notes that follow is inferred from which
// template parameters each line references -- confirm against the repository.
//
// ApplyDropoutWithZ overwrites entries of in_thread_buf in place. For every entry
// it compares a value from z_thread_buf with p_dropout_16bits and passes the result
// (keep == z <= p_dropout_16bits) together with the current entry to execute_dropout.
//
// Variant with using_sign_bit (presumably the pre-change signature -- TODO confirm):
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
// Variant with Step / Offset (presumably the post-change signature -- TODO confirm):
template <typename CThreadBuffer, typename ZThreadBuffer, typename Step, typename Offset>
__host__ __device__ void ApplyDropoutWithZ(CThreadBuffer& in_thread_buf,
ZThreadBuffer& z_thread_buf)
{
// execute_dropout(keep, val): returns the value to store for one entry.
auto execute_dropout = [&](bool keep, DataType val) {
// using_sign_bit variant only: when using_sign_bit is true a dropped value is
// returned negated (-val) instead of zero, and a kept value is returned unscaled;
// otherwise kept values are multiplied by p_dropout_rescale and dropped ones are 0.
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
// Step/Offset variant: the single unconditional form -- rescale when kept, 0 otherwise.
return keep ? val * p_dropout_rescale : float(0);
};
// using_sign_bit variant: visits every (iM, iK) pair of MRepeat x KRepeat and uses
// the same ThreadSliceDesc_M_K offset to index both in_thread_buf and z_thread_buf.
// tmp_index is incremented but never read in the lines visible here.
// NOTE(review): only one closing "});" for the two nested static_for calls appears
// before the other variant's lines start; the second one seems to be shared with
// or lost in the merged hunk.
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
in_thread_buf(offset));
tmp_index = tmp_index + 1;
});
// Step/Offset variant: processes tmp_size = MRepeat * KRepeat / Step entries in one
// flat loop. z_thread_buf is indexed from 0 by i, while in_thread_buf is indexed by
// i + Offset{}, so only a window of in_thread_buf starting at Offset is touched.
// The division is integer division; presumably Step divides MRepeat * KRepeat
// evenly -- TODO confirm, otherwise trailing entries are silently skipped.
constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf(i + Offset{}) =
execute_dropout(z_thread_buf(i) <= p_dropout_16bits, in_thread_buf(i + Offset{}));
});
}
// get raw z matrix with random number for shuffle
// NOTE(review): as with the function above, this is a diff hunk with the +/- markers
// stripped; a single-parameter variant and a Step/Offset variant are interleaved and
// the text does not compile as written. Which lines belong to which variant is
// inferred from the template parameters they use -- confirm against the repository.
//
// GenerateZMatrixAttnFwd fills z_thread_buf with 16-bit values produced by
// ph.get_random_4x16. Values are first written into the local array tmp, then
// copied into z_thread_buf.
//
// Single-parameter variant (presumably pre-change -- TODO confirm):
template <typename ZThreadBuffer>
// Step/Offset variant (presumably post-change -- TODO confirm):
template <typename ZThreadBuffer,
typename Step,
typename Offset> // N3*N4=8
__host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf)
{
// Two definitions of tmp_size appear because both variants are shown: the first
// covers all MRepeat * KRepeat entries, the second only a 1/Step share of them
// (integer division).
constexpr int tmp_size = MRepeat * KRepeat;
constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
// One get_random_4x16 call per group of 4 tmp entries (destination advances by 4).
// NOTE(review): tmp_size / 4 truncates; if tmp_size is not a multiple of 4 the last
// tmp entries are never written by this loop but are still copied out below --
// presumably callers guarantee divisibility; verify.
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
// Second argument advances per call: the single-parameter variant hard-codes a
// stride of 8, the Step/Offset variant takes the stride from Offset{} (the
// trailing comment on the Offset template parameter reads "N3*N4=8").
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{});
}
// NOTE(review): block_sync_lds is defined outside this view, and the merged hunk does
// not show whether this call belongs to one variant or both -- confirm.
block_sync_lds();
// Single-parameter variant: walks (iM, iK) over MRepeat x KRepeat and stores
// tmp[tmp_index] at the ThreadSliceDesc_M_K offset of each pair, consuming tmp in
// visiting order.
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
// Step/Offset variant: flat copy, z_thread_buf(i) = tmp[i] for i in [0, tmp_size).
static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
}
ushort p_dropout_16bits;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment