Commit fb26ec5d authored by danyao12

hd256 bias support

parent 237c93c8
@@ -824,7 +824,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
},
st_acc,
biast_tile);
move_tile_window(bias_dram_window, {kM0, 0});
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
@@ -963,7 +962,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_window, {kM0, 0});
}
// STAGE 6, SGrad^T@Q^T Gemm3
@@ -190,11 +190,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (kTotalPixels / kMinVecLoad);
: (total_pixels / kMinVecLoad);
return kVecLoad;
}
@@ -209,11 +209,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (kTotalPixels / kMinVecLoad);
: (total_pixels / kMinVecLoad);
return kVecLoad;
}
@@ -226,9 +226,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
return kTotalPixels > kMaxVecLoad ? kMaxVecLoad : kTotalPixels;
return total_pixels > kMaxVecLoad ? kMaxVecLoad : total_pixels;
}
template <typename Problem>
@@ -248,11 +248,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (kTotalPixels / kMinVecLoad);
: (total_pixels / kMinVecLoad);
return kVecLoad;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
{
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType);
constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
return kVecLoad;
}
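
For reference, a minimal standalone sketch of the vector-width selection used by GetAlignmentBias above and the sibling alignment helpers, under assumed example values: a 2-byte bias type (e.g. fp16), a 128x128 tile, and 256 threads per block. The numbers are illustrative only, not taken from the kernel configuration.

    // Standalone sketch of the vector-load width selection; element size,
    // tile shape and block size below are assumptions for illustration only.
    constexpr int pick_vec_load(int elem_bytes, int m_per_block, int n_per_block, int block_size)
    {
        const int max_vec      = 16 / elem_bytes;                        // widest load: 16 bytes
        const int min_vec      = 4 / elem_bytes;                         // narrowest load: 4 bytes
        const int total_pixels = m_per_block * n_per_block / block_size; // elements owned per thread
        return (total_pixels / max_vec) >= min_vec ? max_vec : (total_pixels / min_vec);
    }
    // 2-byte bias, 128x128 tile, 256 threads: 64 elements/thread -> 8-element (16-byte) vectors
    static_assert(pick_vec_load(2, 128, 128, 256) == 8, "expected full 16-byte vector loads");
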
@@ -335,25 +354,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize;
// TODO: not correct!
if constexpr(kTotalPixels > 32)
return 8;
else
return 4;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize;
constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
return kTotalPixels / GetTransposedAlignmentBias<Problem>();
return total_pixels / GetAlignmentBias<Problem>();
}
template <typename Problem>
@@ -489,6 +492,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy
sequence<2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t N1 = GetAlignmentBias<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t M1 = get_warp_size() / N0;
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M2 = kMPerBlock / (M1 * M0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
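
A quick sanity check of the factorisation used in MakeBiasTileDistribution above, with assumed values (kBlockSize = 256, warp size 64, kM0 = kN0 = 128, and an 8-element bias alignment from GetAlignmentBias); these are example numbers, not the kernel's actual configuration.

    // Assumed example configuration; only meant to show how the M/N split covers the tile.
    constexpr int kBlockSize = 256;
    constexpr int warp_size  = 64;
    constexpr int kMPerBlock = 128;   // kM0
    constexpr int kNPerBlock = 128;   // kN0

    constexpr int N1 = 8;                       // GetAlignmentBias(): contiguous elements per thread
    constexpr int N0 = kNPerBlock / N1;         // 16 thread groups along N
    constexpr int M1 = warp_size / N0;          // 4 rows covered by one warp per pass
    constexpr int M0 = kBlockSize / warp_size;  // 4 warps per block
    constexpr int M2 = kMPerBlock / (M1 * M0);  // 8 row iterations per thread

    static_assert(M0 * M1 * M2 == kMPerBlock, "M split must cover the whole tile");
    static_assert(N0 * N1 == kNPerBlock, "N split must cover the whole tile");
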
template <typename DataType, index_t MPerBlock, index_t KPerBlock>
CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution()
{
@@ -613,9 +639,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias()
{
// TODO: this is for 3d layout
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
return 16 / sizeof(BiasDataType);
return GetAlignmentBias<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBiasT()
{
return GetTransposedAlignmentBias<Problem>();
}
template <typename Problem>
@@ -1520,42 +1550,46 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return ds_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t N1 = GetAlignmentBias<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t M2 = GetTransposedAlignmentBias<Problem>();
constexpr index_t M1 = get_warp_size() / N0;
constexpr index_t M0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor()
{
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(BiasDataType);
constexpr index_t kKPack = GetSmemKPackBias<Problem>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
// Hold full block data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kMPerBlock % kKPack == 0);
constexpr auto biast_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock / kKPack>{},
number<kNPerBlock / NPerRow>{},
number<NPerRow>{},
number<kKPack>{}),
make_tuple(number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
number<PixelsPerRow + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr index_t kKPack = GetSmemKPackBias<Problem>();
constexpr index_t kKPackT = GetSmemKPackBiasT<Problem>();
constexpr auto biast_lds_block_desc = transform_tensor_descriptor(
biast_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
make_merge_transform(make_tuple(number<kMPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kMPerBlock, kKPack, kKPackT>();
}
return biast_lds_block_desc;
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution()
{
using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile());
return c_block_tensor_type::get_tile_distribution();
}
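
For context, a rough sketch of the bank-padding arithmetic behind the padded LDS row layout in the descriptor above, under assumed values (32 four-byte LDS banks, a 2-byte bias type, kKPack = 8). The actual descriptor is now produced by MakeXTLdsBlockDescriptor; the numbers here are illustrative only.

    // Assumed values for illustration; the bank count is architecture dependent.
    constexpr int Banks        = 32;                     // LDS banks, 4 bytes each
    constexpr int BiasBytes    = 2;                      // e.g. a 16-bit bias type
    constexpr int PixelsPerRow = Banks * 4 / BiasBytes;  // 64 elements span every bank once
    constexpr int kKPack       = 8;                      // contiguous pack along the merged dim
    constexpr int NPerRow      = PixelsPerRow / kKPack;  // 8 packs per physical LDS row
    constexpr int RowStride    = PixelsPerRow + kKPack;  // pad by one pack to stagger banks

    static_assert(PixelsPerRow % kKPack == 0, "a row must hold an integer number of packs");
    static_assert(NPerRow == 8 && RowStride == 72, "64 data elements + 8 padding elements per row");
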
template <typename Problem>
@@ -1681,20 +1715,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot +
smem_size_do + smem_size_lse + smem_size_d +
max(smem_size_bias, smem_size_ds);
constexpr index_t smem_size_stage2 = smem_size_qt + smem_size_bias;
constexpr index_t smem_size_stage3 = smem_size_qt;
constexpr index_t smem_size_stage4 = smem_size_qt + smem_size_do + smem_size_d;
constexpr index_t smem_size_stage5 = smem_size_qt;
constexpr index_t smem_size_stage6 = smem_size_qt + smem_size_ds;
return max(smem_size_stage0_0,
smem_size_stage0_1,
smem_size_stage1,
smem_size_stage2,
smem_size_stage3,
smem_size_stage4,
smem_size_stage5,
smem_size_stage6);
return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
}
template <typename Problem_>
@@ -1718,25 +1740,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm0MFMA;
// Evenly distributed to relieve SQ->TA FIFO pressure
constexpr index_t VMEM_READ__MFMA_Rate = MFMA_INST / VMEM_READ_INST;
constexpr index_t MFMA_Remainder =
MFMA_INST - VMEM_READ__MFMA_Rate * VMEM_READ__MFMA_Rate;
constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;
constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;
// To hide instruction issue latency
constexpr index_t MFMA__LDS_READ_Rate = LDS_READ_INST / MFMA_INST;
constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;
static_for<0, VMEM_READ_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
static_for<0, VMEM_READ__MFMA_Rate, 1>{}([&](auto j) {
static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) {
ignore = j;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
});
});
static_for<0, MFMA_Remainder, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
});
}
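
A worked example of the interleaving ratios computed above, with assumed per-iteration instruction counts (16 MFMAs, 6 VMEM reads, 32 DS reads); the real counts depend on the pipeline's tile configuration.

    // Assumed instruction counts, for illustration only.
    constexpr int MFMA_INST      = 16;
    constexpr int VMEM_READ_INST = 6;
    constexpr int LDS_READ_INST  = 32;

    constexpr int MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;                       // 2 MFMAs behind each VMEM read
    constexpr int MFMA_Remainder     = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;  // 16 - 2*6 = 4 MFMAs issued afterwards
    constexpr int LDS_READ_PER_MFMA  = LDS_READ_INST / MFMA_INST;                        // 2 DS reads slotted behind each MFMA

    static_assert(MFMA_PER_VMEM_READ == 2 && MFMA_Remainder == 4 && LDS_READ_PER_MFMA == 2, "");
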
@@ -1749,12 +1770,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm1MFMA;
// To hide instruction issue latency
constexpr index_t MFMA__LDS_READ_Rate = LDS_READ_INST / MFMA_INST;
constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
});
}
@@ -1768,12 +1789,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm2MFMA;
// To hide instruction issue latency
constexpr index_t MFMA__LDS_WRITE_Rate = LDS_WRITE_INST / MFMA_INST;
constexpr index_t LDS_WRITE_PER_MFMA = LDS_WRITE_INST / MFMA_INST;
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, MFMA__LDS_WRITE_Rate, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write
});
}
@@ -1787,11 +1808,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm3MFMA;
// To hide instruction issue latency
constexpr index_t MFMA__LDS_WRITE_Rate =
constexpr index_t LDS_WRITE_PER_MFMA =
LDS_WRITE_INST / MFMA_INST >= 1 ? LDS_WRITE_INST / MFMA_INST : 1;
constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / MFMA__LDS_WRITE_Rate;
constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA;
constexpr index_t MFMA__LDS_READ_Rate =
constexpr index_t LDS_READ_PER_MFMA =
(MFMA_INST - MFMA_INST_LDS_WRITE) > 0
? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE)
@@ -1801,13 +1822,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, MFMA__LDS_WRITE_Rate, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
});
static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS Read
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
});
}
@@ -1820,13 +1841,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm4MFMA;
// To hide instruction issue latency
constexpr index_t MFMA__LDS_READ_Rate =
constexpr index_t LDS_READ_PER_MFMA =
LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1;
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS Read
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
});
}
@@ -1902,69 +1923,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static constexpr index_t D_LDS_WRITE = 1;
static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t N1 = GetTransposedAlignmentBias<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t M3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackBias<Problem>();
static_assert(kKPack % M3 == 0);
        constexpr index_t M2 = kKPack / M3; // TODO: this dimension could be outside single wave
constexpr index_t M1 = get_warp_size() / (M2 * N0);
constexpr index_t M0 = kBlockSize / get_warp_size();
static_assert(kMPerBlock == M0 * M1 * M2 * M3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2, M3>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2, 1>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<3, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t N1 = GetTransposedAlignmentBias<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t M3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackBias<Problem>();
static_assert(kKPack % M3 == 0);
        constexpr index_t M2 = kKPack / M3; // TODO: this dimension could be outside single wave
constexpr index_t M1 = get_warp_size() / (M2 * N0);
constexpr index_t M0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2, M3>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2, 1>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<1, 3>>{});
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution()
{
using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile());
return c_block_tensor_type::get_tile_distribution();
}
};
} // namespace ck_tile