Commit 06f575a3 authored by danyao12
Browse files

refactor dropout

parent 99436cd4
...@@ -8,8 +8,15 @@ ...@@ -8,8 +8,15 @@
namespace ck_tile { namespace ck_tile {
struct NullBlockDropout template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
struct BlockDropout;
template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropout<false, IsWG32_, IsStoreRandval_>
{ {
static constexpr bool IsDropout = false;
static constexpr bool IsStoreRandval = IsStoreRandval_;
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp> template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
...@@ -22,10 +29,10 @@ struct NullBlockDropout ...@@ -22,10 +29,10 @@ struct NullBlockDropout
} }
}; };
template <bool IsDropout_ = true, bool IsWG32_ = true, bool IsStoreRandval_ = false> template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropout struct BlockDropout<true, IsWG32_, IsStoreRandval_>
{ {
static constexpr bool IsDropout = IsDropout_; static constexpr bool IsDropout = true;
// true: 32*32 warp gemm // true: 32*32 warp gemm
// false: 16*16 warp gemm // false: 16*16 warp gemm
static constexpr bool IsWG32 = IsWG32_; static constexpr bool IsWG32 = IsWG32_;
......
...@@ -915,27 +915,29 @@ struct FmhaBwdDQDKDVKernel ...@@ -915,27 +915,29 @@ struct FmhaBwdDQDKDVKernel
}(); }();
// dropout // dropout
float rp_undrop = 1; float rp_undrop = 1;
float scale_rp_undrop = 1; float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 0;
uint64_t drop_offset = 0;
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
rp_undrop = kargs.rp_undrop; rp_undrop = kargs.rp_undrop;
scale_rp_undrop = kargs.scale_rp_undrop; scale_rp_undrop = kargs.scale_rp_undrop;
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
drop_seed = kargs.drop_seed;
drop_offset = kargs.drop_offset;
} }
FmhaDropout dropout(i_batch, auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
i_nhead, if constexpr(kHasDropout)
kargs.num_head_q, {
drop_seed, return FmhaDropout{i_batch_,
drop_offset, i_nhead_,
rp_undrop, kargs.num_head_q,
p_undrop_in_uint8_t); kargs.drop_seed,
kargs.drop_offset,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t};
}
else
{
return FmhaDropout{};
};
}();
auto randval_dram_window = [&, i_nhead_ = i_nhead]() { auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths = constexpr auto randval_dram_window_lengths =
......
...@@ -749,25 +749,22 @@ struct FmhaFwdKernel ...@@ -749,25 +749,22 @@ struct FmhaFwdKernel
}(); }();
// dropout // dropout
float rp_undrop = 1; auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); if constexpr(kHasDropout)
uint64_t drop_seed = 0; {
uint64_t drop_offset = 0; return FmhaDropout{i_batch_,
i_nhead_,
if constexpr(kHasDropout) kargs.num_head_q,
{ kargs.drop_seed,
rp_undrop = kargs.rp_undrop; kargs.drop_offset,
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; kargs.rp_undrop,
drop_seed = kargs.drop_seed; kargs.p_undrop_in_uint8_t};
drop_offset = kargs.drop_offset; }
} else
FmhaDropout dropout(i_batch, {
i_nhead, return FmhaDropout{};
kargs.num_head_q, };
drop_seed, }();
drop_offset,
rp_undrop,
p_undrop_in_uint8_t);
auto randval_dram_window = [&, i_nhead_ = i_nhead]() { auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths = constexpr auto randval_dram_window_lengths =
......
...@@ -747,25 +747,22 @@ struct FmhaFwdSplitKVKernel ...@@ -747,25 +747,22 @@ struct FmhaFwdSplitKVKernel
}(); }();
// dropout // dropout
float rp_undrop = 1; auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); if constexpr(kHasDropout)
uint64_t drop_seed = 0; {
uint64_t drop_offset = 0; return FmhaDropout{i_batch_,
i_nhead_,
if constexpr(kHasDropout) kargs.num_head_q,
{ kargs.drop_seed,
rp_undrop = kargs.rp_undrop; kargs.drop_offset,
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; kargs.rp_undrop,
drop_seed = kargs.drop_seed; kargs.p_undrop_in_uint8_t};
drop_offset = kargs.drop_offset; }
} else
FmhaDropout dropout(i_batch, {
i_nhead, return FmhaDropout{};
kargs.num_head_q, };
drop_seed, }();
drop_offset,
rp_undrop,
p_undrop_in_uint8_t);
auto randval_dram_window = [&, i_nhead_ = i_nhead]() { auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths = constexpr auto randval_dram_window_lengths =
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment