Commit 237c93c8 authored by danyao12's avatar danyao12
Browse files

bias support

parent ca4a9f00
......@@ -384,7 +384,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
Policy::template MakeBiasTileDistribution<Problem>());
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeD<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
......@@ -555,9 +559,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
const auto bias_tile = load_tile(bias_dram_window);
block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile);
......@@ -571,6 +573,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
st_acc,
biast_tile);
move_tile_window(bias_dram_window, {kM0, 0});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
......@@ -725,6 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_window, {kM0, 0});
__builtin_amdgcn_sched_barrier(0);
}
// STAGE 6, SGrad^T@Q^T Gemm3
......@@ -807,9 +811,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
const auto bias_tile = load_tile(bias_dram_window);
block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile);
......
......@@ -331,21 +331,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentBias()
{
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize;
constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType);
constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (kTotalPixels / kMinVecLoad);
return kVecLoad;
// TODO: not correct!
if constexpr(kTotalPixels > 32)
return 8;
else
return 4;
}
template <typename Problem>
......@@ -617,7 +613,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias()
{
return GetAlignmentBias<Problem>();
// TODO: this is for 3d layout
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
return 16 / sizeof(BiasDataType);
}
template <typename Problem>
......@@ -1682,7 +1680,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t smem_size_stage0_1 = smem_size_v;
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
smem_size_do + smem_size_lse + smem_size_d +
smem_size_ds;
max(smem_size_bias, smem_size_ds);
constexpr index_t smem_size_stage2 = smem_size_qt + smem_size_bias;
constexpr index_t smem_size_stage3 = smem_size_qt;
constexpr index_t smem_size_stage4 = smem_size_qt + smem_size_do + smem_size_d;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment