Commit ed8ef7e5 authored by danyao12's avatar danyao12
Browse files

dropout patch for mrepeat 16*16

parent 94c957b3
......@@ -478,7 +478,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue
if ((hdim <= 128 and ("wg16" in dropout)) or (hdim == 256 and ("wg32" in dropout))):
if (((hdim == 64 or hdim == 128) and ("wg16" in dropout)) or ((hdim == 32 or hdim == 256) and ("wg32" in dropout))):
continue
k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
......
......@@ -53,6 +53,23 @@ class philox
out_tmp[3] = tmp_ph.w;
}
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32x4_t tmp;
tmp[0] = tmp_ph.x;
tmp[1] = tmp_ph.y;
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
out_tmp[1] = tmp[start_idx + 2];
}
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
......
......@@ -60,10 +60,22 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerStep = MWarp * WG::kM;
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
constexpr index_t kMPerStep = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return MWarp * WG::kM * 2;
}
else
{
return MWarp * WG::kM;
}
}();
constexpr index_t kNPerStep = NWarp * WG::kN;
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
......@@ -116,15 +128,27 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
return randval_lds_block_desc;
}
template <typename BlockGemm>
template <typename BlockGemm, bool IsFwd = true>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = 1;
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
constexpr index_t MIterPerWarp = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return 2;
}
else
{
return 1;
}
}();
constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
......@@ -297,22 +321,34 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16);
constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
constexpr index_t kMPerStep = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return MWarp * WG::kM * 2;
}
else
{
return MWarp * WG::kM;
}
}();
constexpr index_t kNPerStep = NWarp * WG::kN;
// register distribute
auto randval =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
auto randval = make_static_distributed_tensor<uint8_t>(
MakeRandValTileDistribution<BlockGemm, false>());
if constexpr(IsWG32)
static_assert(randval.kThreadElementSpaceSize == 16);
else
static_assert(randval.kThreadElementSpaceSize == 4);
static_assert(randval.kThreadElementSpaceSize == 4 ||
randval.kThreadElementSpaceSize == 8);
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
......@@ -324,14 +360,14 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
}
else
{
block_row_start = start_m0_idx / 32;
block_col_start = (start_n0_idx / 32) + get_warp_id() / 2;
block_row_start = start_m0_idx / 32 + i_m0;
block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
}
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t* random_uint8_t_;
if constexpr(!IsWG32)
if constexpr(MBwdWG16SingleIterCheck)
{
uint8_t random_uint8_t[4];
// m0t0 ~m0t15/m0t32~m0t47: 0
......@@ -344,6 +380,16 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
random_uint8_t_ = random_uint8_t;
}
else if constexpr(MBwdWG16MultiIterCheck)
{
uint8_t random_uint8_t[8];
// t0 ~t15/t32~t47: 0
// t16~t31/t48~t63: 1
const index_t start_idx = (get_lane_id() >> 4) & 1;
ph.get_random_8x8(
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
random_uint8_t_ = random_uint8_t;
}
else
{
uint8_t random_uint8_t[16];
......@@ -356,10 +402,11 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
int i_random_idx = 0;
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
randval(r_idx) = random_uint8_t_[i_random_idx++];
constexpr auto p_idx0 =
tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
randval(r_idx) = random_uint8_t_[i_random_idx++];
constexpr auto p_idx0 = tile_distributed_index<i_m0 + idx0.impl_.at(0),
idx0.impl_.at(1),
idx0.impl_.at(2)>{};
constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment