Commit ed8ef7e5 authored by danyao12
Browse files

dropout patch for mrepeat 16*16

parent 94c957b3
...@@ -478,7 +478,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -478,7 +478,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue continue
if ((bias == "no" or bias == "alibi") and dbias == "t"): if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue continue
if ((hdim <= 128 and ("wg16" in dropout)) or (hdim == 256 and ("wg32" in dropout))): if (((hdim == 64 or hdim == 128) and ("wg16" in dropout)) or ((hdim == 32 or hdim == 256) and ("wg32" in dropout))):
continue continue
k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
......
...@@ -53,6 +53,23 @@ class philox ...@@ -53,6 +53,23 @@ class philox
out_tmp[3] = tmp_ph.w; out_tmp[3] = tmp_ph.w;
} }
// Fill `out` with 8 random bytes taken from a single Philox 4x32 draw.
//
// out         - destination buffer; receives two 32-bit words (8 bytes).
// subsequence - forwarded unchanged to get_philox_4x32() to select the draw.
// start_idx   - selects the word pair: words [start_idx] and [start_idx + 2]
//               of the four generated words are emitted, in that order.
//               Must be 0 or 1 so that start_idx + 2 stays inside the 4 words.
//
// NOTE(review): `out` is written through a uint32_t*; presumably callers pass
// storage that tolerates 4-byte stores - confirm against the call sites.
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
                                        const unsigned long long subsequence,
                                        const index_t start_idx) const
{
    const uint4 philox_out = get_philox_4x32(subsequence);

    // Repack the four named components into an indexable vector so the
    // word pair can be chosen with a run-time index.
    uint32x4_t words;
    words[0] = philox_out.x;
    words[1] = philox_out.y;
    words[2] = philox_out.z;
    words[3] = philox_out.w;

    // Emit the selected pair as two consecutive 32-bit stores.
    uint32_t* dst = reinterpret_cast<uint32_t*>(out);
    dst[0]        = words[start_idx];
    dst[1]        = words[start_idx + 2];
}
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out, CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
const unsigned long long subsequence, const unsigned long long subsequence,
const index_t start_idx) const const index_t start_idx) const
......
...@@ -60,10 +60,22 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_> ...@@ -60,10 +60,22 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
{ {
constexpr auto config = constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>(); BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>; using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t MWarp = config.template at<1>(); using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t NWarp = config.template at<2>(); constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kMPerStep = MWarp * WG::kM; constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
constexpr index_t kMPerStep = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return MWarp * WG::kM * 2;
}
else
{
return MWarp * WG::kM;
}
}();
constexpr index_t kNPerStep = NWarp * WG::kN; constexpr index_t kNPerStep = NWarp * WG::kN;
const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
...@@ -116,15 +128,27 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_> ...@@ -116,15 +128,27 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
return randval_lds_block_desc; return randval_lds_block_desc;
} }
template <typename BlockGemm> template <typename BlockGemm, bool IsFwd = true>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
{ {
constexpr auto config = constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>(); BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
constexpr index_t MWarp = config.template at<1>(); using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t NWarp = config.template at<2>(); constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t MIterPerWarp = 1; constexpr index_t NWarp = config.template at<2>();
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
constexpr index_t MIterPerWarp = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return 2;
}
else
{
return 1;
}
}();
constexpr index_t NIterPerWarp = 1; constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
...@@ -297,22 +321,34 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_> ...@@ -297,22 +321,34 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
{ {
constexpr auto config = constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>(); BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>; using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>(); constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>(); constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM; constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN; constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr index_t kMPerStep = MWarp * WG::kM; constexpr bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16);
constexpr index_t kNPerStep = NWarp * WG::kN; constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
constexpr index_t kMPerStep = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return MWarp * WG::kM * 2;
}
else
{
return MWarp * WG::kM;
}
}();
constexpr index_t kNPerStep = NWarp * WG::kN;
// register distribute // register distribute
auto randval = auto randval = make_static_distributed_tensor<uint8_t>(
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>()); MakeRandValTileDistribution<BlockGemm, false>());
if constexpr(IsWG32) if constexpr(IsWG32)
static_assert(randval.kThreadElementSpaceSize == 16); static_assert(randval.kThreadElementSpaceSize == 16);
else else
static_assert(randval.kThreadElementSpaceSize == 4); static_assert(randval.kThreadElementSpaceSize == 4 ||
randval.kThreadElementSpaceSize == 8);
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
...@@ -324,14 +360,14 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_> ...@@ -324,14 +360,14 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
} }
else else
{ {
block_row_start = start_m0_idx / 32; block_row_start = start_m0_idx / 32 + i_m0;
block_col_start = (start_n0_idx / 32) + get_warp_id() / 2; block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
} }
uint2 rowcol = make_uint2(block_row_start, block_col_start); uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number // generate random number
uint8_t* random_uint8_t_; uint8_t* random_uint8_t_;
if constexpr(!IsWG32) if constexpr(MBwdWG16SingleIterCheck)
{ {
uint8_t random_uint8_t[4]; uint8_t random_uint8_t[4];
// m0t0 ~m0t15/m0t32~m0t47: 0 // m0t0 ~m0t15/m0t32~m0t47: 0
...@@ -344,6 +380,16 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_> ...@@ -344,6 +380,16 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx); random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
random_uint8_t_ = random_uint8_t; random_uint8_t_ = random_uint8_t;
} }
else if constexpr(MBwdWG16MultiIterCheck)
{
uint8_t random_uint8_t[8];
// t0 ~t15/t32~t47: 0
// t16~t31/t48~t63: 1
const index_t start_idx = (get_lane_id() >> 4) & 1;
ph.get_random_8x8(
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
random_uint8_t_ = random_uint8_t;
}
else else
{ {
uint8_t random_uint8_t[16]; uint8_t random_uint8_t[16];
...@@ -356,10 +402,11 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_> ...@@ -356,10 +402,11 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
int i_random_idx = 0; int i_random_idx = 0;
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
randval(r_idx) = random_uint8_t_[i_random_idx++]; randval(r_idx) = random_uint8_t_[i_random_idx++];
constexpr auto p_idx0 = constexpr auto p_idx0 = tile_distributed_index<i_m0 + idx0.impl_.at(0),
tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{}; idx0.impl_.at(1),
idx0.impl_.at(2)>{};
constexpr auto p_idx1 = tile_distributed_index<i_n0>{}; constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment