Commit 06f575a3 authored by danyao12's avatar danyao12
Browse files

refactor dropout

parent 99436cd4
......@@ -8,8 +8,15 @@
namespace ck_tile {
struct NullBlockDropout
template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
struct BlockDropout;
template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropout<false, IsWG32_, IsStoreRandval_>
{
static constexpr bool IsDropout = false;
static constexpr bool IsStoreRandval = IsStoreRandval_;
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
__host__ __device__ static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
......@@ -22,10 +29,10 @@ struct NullBlockDropout
}
};
template <bool IsDropout_ = true, bool IsWG32_ = true, bool IsStoreRandval_ = false>
struct BlockDropout
template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropout<true, IsWG32_, IsStoreRandval_>
{
static constexpr bool IsDropout = IsDropout_;
static constexpr bool IsDropout = true;
// true: 32*32 warp gemm
// false: 16*16 warp gemm
static constexpr bool IsWG32 = IsWG32_;
......
......@@ -915,27 +915,29 @@ struct FmhaBwdDQDKDVKernel
}();
// dropout
float rp_undrop = 1;
float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 0;
uint64_t drop_offset = 0;
float rp_undrop = 1;
float scale_rp_undrop = 1;
if constexpr(kHasDropout)
{
rp_undrop = kargs.rp_undrop;
scale_rp_undrop = kargs.scale_rp_undrop;
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
drop_seed = kargs.drop_seed;
drop_offset = kargs.drop_offset;
rp_undrop = kargs.rp_undrop;
scale_rp_undrop = kargs.scale_rp_undrop;
}
FmhaDropout dropout(i_batch,
i_nhead,
kargs.num_head_q,
drop_seed,
drop_offset,
rp_undrop,
p_undrop_in_uint8_t);
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
if constexpr(kHasDropout)
{
return FmhaDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
kargs.drop_seed,
kargs.drop_offset,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t};
}
else
{
return FmhaDropout{};
};
}();
auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths =
......
......@@ -749,25 +749,22 @@ struct FmhaFwdKernel
}();
// dropout
float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 0;
uint64_t drop_offset = 0;
if constexpr(kHasDropout)
{
rp_undrop = kargs.rp_undrop;
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
drop_seed = kargs.drop_seed;
drop_offset = kargs.drop_offset;
}
FmhaDropout dropout(i_batch,
i_nhead,
kargs.num_head_q,
drop_seed,
drop_offset,
rp_undrop,
p_undrop_in_uint8_t);
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
if constexpr(kHasDropout)
{
return FmhaDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
kargs.drop_seed,
kargs.drop_offset,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t};
}
else
{
return FmhaDropout{};
};
}();
auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths =
......
......@@ -747,25 +747,22 @@ struct FmhaFwdSplitKVKernel
}();
// dropout
float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 0;
uint64_t drop_offset = 0;
if constexpr(kHasDropout)
{
rp_undrop = kargs.rp_undrop;
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
drop_seed = kargs.drop_seed;
drop_offset = kargs.drop_offset;
}
FmhaDropout dropout(i_batch,
i_nhead,
kargs.num_head_q,
drop_seed,
drop_offset,
rp_undrop,
p_undrop_in_uint8_t);
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
if constexpr(kHasDropout)
{
return FmhaDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
kargs.drop_seed,
kargs.drop_offset,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t};
}
else
{
return FmhaDropout{};
};
}();
auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment