Unverified Commit 1973903f authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

Hacking ck_tile fmha Dropout facility (#1344)



* Add NullBlockDropout to be used when kHasDropout is false

* Change to BlockDropout::Run() for forward to reduce conditional checkings

* Re-format files

---------
Co-authored-by: default avatarPoYen, Chen <PoYen.Chen@amd.com>
parent 8faec23c
...@@ -8,6 +8,20 @@ ...@@ -8,6 +8,20 @@
namespace ck_tile { namespace ck_tile {
struct NullBlockDropout
{
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
__host__ __device__ static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
(void)randval_dram_block_window_tmp;
(void)seqlen_qk_start;
return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
}
};
struct BlockDropout struct BlockDropout
{ {
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch,
...@@ -195,6 +209,42 @@ struct BlockDropout ...@@ -195,6 +209,42 @@ struct BlockDropout
MakeRandValLdsShuffleTileDistribution<BlockGemm>()); MakeRandValLdsShuffleTileDistribution<BlockGemm>());
const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
if(is_store_randval)
{
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
int block_col_start = (start_n0_idx / WG::kN) + i_n0;
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t,
reinterpret_cast<unsigned long long&>(rowcol));
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
// save to LDS
store_tile(randval_lds_window, randval_dist_generated);
block_sync_lds();
// read from LDS to register
auto randval = load_tile(randval_lds_read_window);
// save to Global
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {0, kNPerStep});
});
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
});
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
};
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
...@@ -232,23 +282,8 @@ struct BlockDropout ...@@ -232,23 +282,8 @@ struct BlockDropout
: PComputeDataType(0); : PComputeDataType(0);
}); });
}); });
// save to Global
if(is_store_randval)
{
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {0, kNPerStep});
}
}); });
if(is_store_randval)
{
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
}
}); });
if(is_store_randval)
{
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
}
} }
template <typename BlockGemm, template <typename BlockGemm,
......
...@@ -744,29 +744,23 @@ struct FmhaFwdKernel ...@@ -744,29 +744,23 @@ struct FmhaFwdKernel
} }
}(); }();
// dropout auto dropout = [&]() {
float rp_undrop = 1; if constexpr(kHasDropout)
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); {
uint64_t drop_seed = 0; return BlockDropout{i_batch,
uint64_t drop_offset = 0; i_nhead,
bool is_store_randval = false; kargs.num_head_q,
kargs.drop_seed,
if constexpr(kHasDropout) kargs.drop_offset,
{ kargs.rp_undrop,
rp_undrop = kargs.rp_undrop; kargs.p_undrop_in_uint8_t,
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; kargs.is_store_randval};
drop_seed = kargs.drop_seed; }
drop_offset = kargs.drop_offset; else
is_store_randval = kargs.is_store_randval; {
} return NullBlockDropout{};
BlockDropout dropout(i_batch, };
i_nhead, }();
kargs.num_head_q,
drop_seed,
drop_offset,
rp_undrop,
p_undrop_in_uint8_t,
is_store_randval);
auto randval_dram_window = [&, i_nhead_ = i_nhead]() { auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths = constexpr auto randval_dram_window_lengths =
......
...@@ -100,6 +100,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -100,6 +100,8 @@ struct BlockFmhaPipelineQRKSVS
static constexpr const char* name = "qr"; static constexpr const char* name = "qr";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
...@@ -139,7 +141,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -139,7 +141,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const DropoutType& dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -246,7 +248,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -246,7 +248,7 @@ struct BlockFmhaPipelineQRKSVS
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>( auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start); randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window = auto v_dram_window =
...@@ -486,7 +488,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -486,7 +488,7 @@ struct BlockFmhaPipelineQRKSVS
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>( dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
} }
...@@ -618,7 +620,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -618,7 +620,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const DropoutType& dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
......
...@@ -112,6 +112,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -112,6 +112,8 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr const char* name = "qr_async"; static constexpr const char* name = "qr_async";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
...@@ -151,7 +153,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -151,7 +153,7 @@ struct BlockFmhaPipelineQRKSVSAsync
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const DropoutType& dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -298,7 +300,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -298,7 +300,7 @@ struct BlockFmhaPipelineQRKSVSAsync
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>( auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start); randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window = auto v_dram_window =
...@@ -571,7 +573,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -571,7 +573,7 @@ struct BlockFmhaPipelineQRKSVSAsync
{ {
auto randval_ptr = auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>(); reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>( dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr, randval_ptr,
seqlen_k_start + i_total_loops * kN0, seqlen_k_start + i_total_loops * kN0,
p_compute, p_compute,
...@@ -728,7 +730,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -728,7 +730,7 @@ struct BlockFmhaPipelineQRKSVSAsync
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const DropoutType& dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment