Commit 4388b767 authored by danyao12's avatar danyao12
Browse files

optimize fwd v2 dropout to eliminate scratch

parent 00cb7e41
...@@ -198,55 +198,41 @@ struct BlockwiseDropout ...@@ -198,55 +198,41 @@ struct BlockwiseDropout
}); });
} }
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false> template <typename CThreadBuffer, typename ZThreadBuffer, typename Step, typename Offset>
__host__ __device__ void ApplyDropoutWithZ(CThreadBuffer& in_thread_buf, __host__ __device__ void ApplyDropoutWithZ(CThreadBuffer& in_thread_buf,
ZThreadBuffer& z_thread_buf) ZThreadBuffer& z_thread_buf)
{ {
auto execute_dropout = [&](bool keep, DataType val) { auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit) return keep ? val * p_dropout_rescale : float(0);
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
}; };
int tmp_index = 0; constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
static_for<0, MRepeat, 1>{}([&](auto iM) { static_for<0, tmp_size, 1>{}([&](auto i) {
static_for<0, KRepeat, 1>{}([&](auto iK) { in_thread_buf(i + Offset{}) =
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{}; execute_dropout(z_thread_buf(i) <= p_dropout_16bits, in_thread_buf(i + Offset{}));
in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
in_thread_buf(offset));
tmp_index = tmp_index + 1;
});
}); });
} }
// get raw z matrix with random number for shuffle // get raw z matrix with random number for shuffle
template <typename ZThreadBuffer> template <typename ZThreadBuffer,
typename Step,
typename Offset> // N3*N4=8
__host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph, __host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph,
index_t element_global_1d_id, index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf) ZThreadBuffer& z_thread_buf)
{ {
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 4;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8); ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{});
} }
block_sync_lds(); static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
} }
ushort p_dropout_16bits; ushort p_dropout_16bits;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment