Commit 237c93c8 authored by danyao12
Browse files

bias support

parent ca4a9f00
...@@ -384,7 +384,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -384,7 +384,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
Policy::template MakeBiasTileDistribution<Problem>()); Policy::template MakeBiasTileDistribution<Problem>());
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>( BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>())); static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeD<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>( auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>()); biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
...@@ -555,9 +559,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -555,9 +559,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
const auto bias_tile = load_tile(bias_dram_window);
const auto bias_tile = load_tile(bias_dram_window);
block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>( auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>()); Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile); shuffle_tile(bias_shuffle_tmp, bias_tile);
...@@ -571,6 +573,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -571,6 +573,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
st_acc, st_acc,
biast_tile); biast_tile);
move_tile_window(bias_dram_window, {kM0, 0}); move_tile_window(bias_dram_window, {kM0, 0});
__builtin_amdgcn_sched_barrier(0);
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
...@@ -725,6 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -725,6 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_window, dbiast_shuffle_tmp); store_tile(dbias_dram_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_window, {kM0, 0}); move_tile_window(dbias_dram_window, {kM0, 0});
__builtin_amdgcn_sched_barrier(0);
} }
// STAGE 6, SGrad^T@Q^T Gemm3 // STAGE 6, SGrad^T@Q^T Gemm3
...@@ -807,9 +811,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -807,9 +811,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
const auto bias_tile = load_tile(bias_dram_window);
const auto bias_tile = load_tile(bias_dram_window);
block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>( auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>()); Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile); shuffle_tile(bias_shuffle_tmp, bias_tile);
......
...@@ -331,21 +331,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -331,21 +331,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentBias() CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentBias()
{ {
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize; constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize;
constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType); // TODO: not correct!
constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType); if constexpr(kTotalPixels > 32)
return 8;
constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad) else
? kMaxVecLoad return 4;
: (kTotalPixels / kMinVecLoad);
return kVecLoad;
} }
template <typename Problem> template <typename Problem>
...@@ -617,7 +613,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -617,7 +613,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias()
{ {
return GetAlignmentBias<Problem>(); // TODO: this is for 3d layout
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
return 16 / sizeof(BiasDataType);
} }
template <typename Problem> template <typename Problem>
...@@ -1682,7 +1680,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1682,7 +1680,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t smem_size_stage0_1 = smem_size_v; constexpr index_t smem_size_stage0_1 = smem_size_v;
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot + constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
smem_size_do + smem_size_lse + smem_size_d + smem_size_do + smem_size_lse + smem_size_d +
smem_size_ds; max(smem_size_bias, smem_size_ds);
constexpr index_t smem_size_stage2 = smem_size_qt + smem_size_bias; constexpr index_t smem_size_stage2 = smem_size_qt + smem_size_bias;
constexpr index_t smem_size_stage3 = smem_size_qt; constexpr index_t smem_size_stage3 = smem_size_qt;
constexpr index_t smem_size_stage4 = smem_size_qt + smem_size_do + smem_size_d; constexpr index_t smem_size_stage4 = smem_size_qt + smem_size_do + smem_size_d;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment