"vscode:/vscode.git/clone" did not exist on "0fd1d6368bbd2a3d19bb3886dce6a7a916ee7315"
Commit fb26ec5d authored by danyao12

hd256 bias support

parent 237c93c8
@@ -824,7 +824,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
                     },
                     st_acc,
                     biast_tile);
-                move_tile_window(bias_dram_window, {kM0, 0});
            }
            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
@@ -963,7 +962,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
                    Policy::template MakeBiasTileDistribution<Problem>());
                shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
                store_tile(dbias_dram_window, dbiast_shuffle_tmp);
-               move_tile_window(dbias_dram_window, {kM0, 0});
            }
            // STAGE 6, SGrad^T@Q^T Gemm3
...
@@ -190,11 +190,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
        constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
-       constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
-       constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
+       constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
+       constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                        ? kMaxVecLoad
-                                       : (kTotalPixels / kMinVecLoad);
+                                       : (total_pixels / kMinVecLoad);
        return kVecLoad;
    }
@@ -209,11 +209,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
        constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
-       constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
-       constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
+       constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
+       constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                        ? kMaxVecLoad
-                                       : (kTotalPixels / kMinVecLoad);
+                                       : (total_pixels / kMinVecLoad);
        return kVecLoad;
    }
@@ -226,9 +226,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
        constexpr index_t kKPerBlock  = Problem::BlockFmhaShape::kK2;
        constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
-       constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
-       return kTotalPixels > kMaxVecLoad ? kMaxVecLoad : kTotalPixels;
+       constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
+       return total_pixels > kMaxVecLoad ? kMaxVecLoad : total_pixels;
    }

    template <typename Problem>
@@ -248,11 +248,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
        constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
-       constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
-       constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
+       constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
+       constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                        ? kMaxVecLoad
-                                       : (kTotalPixels / kMinVecLoad);
+                                       : (total_pixels / kMinVecLoad);
+       return kVecLoad;
+   }
+
+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
+   {
+       using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
+
+       constexpr index_t kBlockSize = Problem::kBlockSize;
+       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+
+       constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType);
+       constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType);
+
+       constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
+
+       constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
+                                        ? kMaxVecLoad
+                                        : (total_pixels / kMinVecLoad);
        return kVecLoad;
    }
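The new GetAlignmentBias() reuses the same vector-load heuristic the Q/K/dO loads already use. A minimal standalone sketch of that heuristic under an assumed configuration (fp16 bias, kM0 = kN0 = 128 tile, 256-thread block; these numbers are for illustration only and are not taken from the kernel config):

    // Illustration of the vector-load heuristic above, assumed fp16 bias on a 128x128 tile.
    constexpr int kMaxVecLoad  = 16 / 2;          // 8 elements per 16-byte load
    constexpr int kMinVecLoad  = 4 / 2;           // 2 elements per 4-byte load
    constexpr int total_pixels = 128 * 128 / 256; // 64 elements owned by each thread
    constexpr int kVecLoad     = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                     ? kMaxVecLoad                   // enough work per thread: widest load
                                     : (total_pixels / kMinVecLoad); // otherwise fall back
    static_assert(kVecLoad == 8, "each thread issues 16-byte bias loads in this configuration");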
@@ -335,25 +354,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-       constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize;
-       // TODO: not correct!
-       if constexpr(kTotalPixels > 32)
-           return 8;
-       else
-           return 4;
-   }
-
-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
-   {
-       constexpr index_t kBlockSize = Problem::kBlockSize;
-       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-       constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize;
-       return kTotalPixels / GetTransposedAlignmentBias<Problem>();
+       constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
+       return total_pixels / GetAlignmentBias<Problem>();
    }

    template <typename Problem>
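After the refactor, GetTransposedAlignmentBias() is simply the per-thread element count divided by GetAlignmentBias(), so the two access widths always factor total_pixels exactly. Continuing the assumed fp16 / 128x128 / 256-thread example from the previous sketch (illustrative numbers only):

    // Assumed configuration carried over from the sketch above.
    constexpr int bias_pixels_per_thread = 128 * 128 / 256; // 64
    constexpr int bias_alignment         = 8;               // GetAlignmentBias() result in that sketch
    constexpr int bias_alignment_t       = bias_pixels_per_thread / bias_alignment; // 8
    static_assert(bias_alignment * bias_alignment_t == bias_pixels_per_thread,
                  "the row and transposed access widths tile each thread's elements exactly");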
@@ -489,6 +492,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy
                                       sequence<2>>{});
    }

+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution()
+   {
+       constexpr index_t kBlockSize = Problem::kBlockSize;
+       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+
+       constexpr index_t N1 = GetAlignmentBias<Problem>();
+       constexpr index_t N0 = kNPerBlock / N1;
+       constexpr index_t M1 = get_warp_size() / N0;
+       constexpr index_t M0 = kBlockSize / get_warp_size();
+       constexpr index_t M2 = kMPerBlock / (M1 * M0);
+
+       return make_static_tile_distribution(
+           tile_distribution_encoding<sequence<>,
+                                      tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
+                                      tuple<sequence<1>, sequence<1, 2>>,
+                                      tuple<sequence<0>, sequence<1, 0>>,
+                                      sequence<1, 2>,
+                                      sequence<2, 1>>{});
+   }
+
    template <typename DataType, index_t MPerBlock, index_t KPerBlock>
    CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution()
    {
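In the new MakeBiasTileDistribution(), the N dimension splits into N0 lane groups of N1 contiguous elements (N1 being the bias load alignment), while M splits across waves (M0), lanes within a wave (M1), and a per-thread repeat (M2). A standalone sketch of that factorization under the same assumed configuration (fp16 bias, 128x128 tile, 256 threads, 64-lane waves; not the kernel's guaranteed values):

    // Assumed: kBlockSize = 256, kMPerBlock = kNPerBlock = 128, warp size 64, N1 = 8.
    constexpr int kBlockSize = 256, kMPerBlock = 128, kNPerBlock = 128, warp_size = 64;
    constexpr int N1 = 8;                       // contiguous bias elements per thread load
    constexpr int N0 = kNPerBlock / N1;         // 16 lane groups across N
    constexpr int M1 = warp_size / N0;          // 4 rows covered by one wave
    constexpr int M0 = kBlockSize / warp_size;  // 4 waves in the block
    constexpr int M2 = kMPerBlock / (M1 * M0);  // 8 row repeats per thread
    static_assert(M0 * M1 * M2 == kMPerBlock && N0 * N1 == kNPerBlock,
                  "the decomposition covers the whole kM0 x kN0 bias tile");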
@@ -613,9 +639,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias()
    {
-       // TODO: this is for 3d layout
-       using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
-       return 16 / sizeof(BiasDataType);
+       return GetAlignmentBias<Problem>();
+   }
+
+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBiasT()
+   {
+       return GetTransposedAlignmentBias<Problem>();
    }

    template <typename Problem>
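GetSmemKPackBias() now returns the same width the bias is loaded with from global memory rather than a fixed 16-byte pack, and the new GetSmemKPackBiasT() does the same for the transposed path. A hypothetical comparison (not the kernel's real configuration) of where the old and new packs diverge:

    // Hypothetical fp16 bias on a 64x32 tile with 256 threads, for illustration only.
    constexpr int elem_bytes   = 2;
    constexpr int total_pixels = 64 * 32 / 256;   // 8 elements per thread
    constexpr int old_kpack    = 16 / elem_bytes; // fixed 16-byte pack -> 8
    constexpr int max_vec      = 16 / elem_bytes; // 8
    constexpr int min_vec      = 4 / elem_bytes;  // 2
    constexpr int new_kpack    = ((total_pixels / max_vec) >= min_vec)
                                     ? max_vec
                                     : (total_pixels / min_vec); // 8 / 2 = 4
    static_assert(old_kpack == 8 && new_kpack == 4,
                  "the LDS pack now tracks the per-thread element count instead of a fixed 16 bytes");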
@@ -1520,42 +1550,46 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        return ds_block_dstr;
    }

+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
+   {
+       constexpr index_t kBlockSize = Problem::kBlockSize;
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+
+       constexpr index_t N1 = GetAlignmentBias<Problem>();
+       constexpr index_t N0 = kNPerBlock / N1;
+       constexpr index_t M2 = GetTransposedAlignmentBias<Problem>();
+       constexpr index_t M1 = get_warp_size() / N0;
+       constexpr index_t M0 = kBlockSize / get_warp_size();
+
+       return make_static_tile_distribution(
+           tile_distribution_encoding<sequence<>,
+                                      tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
+                                      tuple<sequence<1>, sequence<1, 2>>,
+                                      tuple<sequence<0>, sequence<1, 0>>,
+                                      sequence<2, 1>,
+                                      sequence<1, 2>>{});
+   }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor()
    {
-       using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
-       constexpr index_t Banks        = 32; // TODO: need change based on arch
-       constexpr index_t PixelsPerRow = Banks * 4 / sizeof(BiasDataType);
-       constexpr index_t kKPack       = GetSmemKPackBias<Problem>();
-       constexpr index_t kMPerBlock   = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kNPerBlock   = Problem::BlockFmhaShape::kN0;
-       static_assert(PixelsPerRow % kKPack == 0);
-       constexpr index_t NPerRow = PixelsPerRow / kKPack;
-       static_assert(kNPerBlock % NPerRow == 0);
-       static_assert(kMPerBlock % kKPack == 0);
-
-       constexpr auto biast_lds_block_desc_0 = make_naive_tensor_descriptor(
-           make_tuple(number<kMPerBlock / kKPack>{},
-                      number<kNPerBlock / NPerRow>{},
-                      number<NPerRow>{},
-                      number<kKPack>{}),
-           make_tuple(number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
-                      number<PixelsPerRow + kKPack>{},
-                      number<kKPack>{},
-                      number<1>{}),
-           number<kKPack>{},
-           number<1>{});
-
-       constexpr auto biast_lds_block_desc = transform_tensor_descriptor(
-           biast_lds_block_desc_0,
-           make_tuple(
-               make_merge_transform(make_tuple(number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
-               make_merge_transform(make_tuple(number<kMPerBlock / kKPack>{}, number<kKPack>{}))),
-           make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
-           make_tuple(sequence<1>{}, sequence<0>{}));
-
-       return biast_lds_block_desc;
+       // Hold full block data
+       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+       constexpr index_t kKPack     = GetSmemKPackBias<Problem>();
+       constexpr index_t kKPackT    = GetSmemKPackBiasT<Problem>();
+
+       return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kMPerBlock, kKPack, kKPackT>();
+   }
+
+   template <typename BlockGemm>
+   CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution()
+   {
+       using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile());
+       return c_block_tensor_type::get_tile_distribution();
    }

    template <typename Problem>
@@ -1681,20 +1715,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
                                             smem_size_do + smem_size_lse + smem_size_d +
                                             max(smem_size_bias, smem_size_ds);
-       constexpr index_t smem_size_stage2 = smem_size_qt + smem_size_bias;
-       constexpr index_t smem_size_stage3 = smem_size_qt;
-       constexpr index_t smem_size_stage4 = smem_size_qt + smem_size_do + smem_size_d;
-       constexpr index_t smem_size_stage5 = smem_size_qt;
-       constexpr index_t smem_size_stage6 = smem_size_qt + smem_size_ds;

-       return max(smem_size_stage0_0,
-                  smem_size_stage0_1,
-                  smem_size_stage1,
-                  smem_size_stage2,
-                  smem_size_stage3,
-                  smem_size_stage4,
-                  smem_size_stage5,
-                  smem_size_stage6);
+       return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
    }

    template <typename Problem_>
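Dropping stage2 through stage6 from the max() leaves the result unchanged as long as stage1 already accounts for every buffer those stages use, which appears to hold here since each dropped stage summed a subset of stage1's terms. A small check of that dominance argument with made-up buffer sizes rather than the kernel's real ones:

    // Made-up buffer sizes purely to illustrate why the shorter max() is equivalent.
    constexpr int qt = 8, q = 4, dot = 4, do_ = 4, lse = 1, d = 1, bias = 4, ds = 4;
    constexpr int stage1 = qt + q + dot + do_ + lse + d + (bias > ds ? bias : ds);
    constexpr int stage2 = qt + bias;    // subset of stage1's terms
    constexpr int stage4 = qt + do_ + d; // subset of stage1's terms
    constexpr int stage6 = qt + ds;      // subset of stage1's terms
    static_assert(stage2 <= stage1 && stage4 <= stage1 && stage6 <= stage1,
                  "stage1 dominates the dropped stages, so the max over the first three suffices");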
@@ -1718,25 +1740,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t MFMA_INST = Gemm0MFMA;

        // Evenly distributed to relieve SQ->TA FIFO pressure
-       constexpr index_t VMEM_READ__MFMA_Rate = MFMA_INST / VMEM_READ_INST;
-       constexpr index_t MFMA_Remainder =
-           MFMA_INST - VMEM_READ__MFMA_Rate * VMEM_READ__MFMA_Rate;
+       constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;
+       constexpr index_t MFMA_Remainder     = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;
        // To hide instruction issue latency
-       constexpr index_t MFMA__LDS_READ_Rate = LDS_READ_INST / MFMA_INST;
+       constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;

        static_for<0, VMEM_READ_INST, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
-           static_for<0, VMEM_READ__MFMA_Rate, 1>{}([&](auto j) {
+           static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) {
                ignore = j;
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-               __builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS read
+               __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
            });
        });
        static_for<0, MFMA_Remainder, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS read
+           __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
        });
    }
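The Gemm0 scheduler interleaves one VMEM read with a fixed number of MFMAs, each followed by a fixed number of DS reads, and the corrected MFMA_Remainder now mops up whatever MFMAs are left when MFMA_INST is not a multiple of VMEM_READ_INST. A hypothetical instruction budget (not taken from the kernel) showing the resulting issue pattern:

    // Hypothetical per-iteration instruction counts, for illustration only.
    constexpr int VMEM_READ_INST = 3;
    constexpr int MFMA_INST      = 8;
    constexpr int LDS_READ_INST  = 16;
    constexpr int MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;                      // 2 MFMAs per VMEM read
    constexpr int MFMA_Remainder     = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST; // 2 trailing MFMAs
    constexpr int LDS_READ_PER_MFMA  = LDS_READ_INST / MFMA_INST;                       // 2 DS reads per MFMA
    // Issue order: (VMEM, [MFMA, DS read x2] x2) x3, then [MFMA, DS read x2] x2 as the remainder.
    // The previous expression (rate * rate) would give 8 - 2*2 = 4 here, over-counting the tail.
    static_assert(VMEM_READ_INST * MFMA_PER_VMEM_READ + MFMA_Remainder == MFMA_INST,
                  "every MFMA is scheduled exactly once");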
@@ -1749,12 +1770,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t MFMA_INST = Gemm1MFMA;

        // To hide instruction issue latency
-       constexpr index_t MFMA__LDS_READ_Rate = LDS_READ_INST / MFMA_INST;
+       constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;

        static_for<0, MFMA_INST, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS read
+           __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
        });
    }
@@ -1768,12 +1789,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t MFMA_INST = Gemm2MFMA;

        // To hide instruction issue latency
-       constexpr index_t MFMA__LDS_WRITE_Rate = LDS_WRITE_INST / MFMA_INST;
+       constexpr index_t LDS_WRITE_PER_MFMA = LDS_WRITE_INST / MFMA_INST;

        static_for<0, MFMA_INST, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x200, MFMA__LDS_WRITE_Rate, 0); // DS write
+           __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write
        });
    }
@@ -1787,11 +1808,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t MFMA_INST = Gemm3MFMA;

        // To hide instruction issue latency
-       constexpr index_t MFMA__LDS_WRITE_Rate =
+       constexpr index_t LDS_WRITE_PER_MFMA =
            LDS_WRITE_INST / MFMA_INST >= 1 ? LDS_WRITE_INST / MFMA_INST : 1;
-       constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / MFMA__LDS_WRITE_Rate;
-       constexpr index_t MFMA__LDS_READ_Rate =
+       constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA;
+       constexpr index_t LDS_READ_PER_MFMA =
            (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
                ? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
                      ? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE)
@@ -1801,13 +1822,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x200, MFMA__LDS_WRITE_Rate, 0); // DS Write
+           __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
        });
        static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS Read
+           __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
        });
    }
@@ -1820,13 +1841,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr index_t MFMA_INST = Gemm4MFMA;

        // To hide instruction issue latency
-       constexpr index_t MFMA__LDS_READ_Rate =
+       constexpr index_t LDS_READ_PER_MFMA =
            LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1;

        static_for<0, MFMA_INST, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS Read
+           __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
        });
    }
@@ -1902,69 +1923,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        static constexpr index_t D_LDS_WRITE      = 1;
        static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
    };
-
-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution()
-   {
-       constexpr index_t kBlockSize = Problem::kBlockSize;
-       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-
-       constexpr index_t N1 = GetTransposedAlignmentBias<Problem>();
-       constexpr index_t N0 = kNPerBlock / N1; // P
-
-       constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
-       static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-       constexpr index_t M3 = total_pixels / N1;
-       constexpr index_t kKPack = GetSmemKPackBias<Problem>();
-       static_assert(kKPack % M3 == 0);
-       constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave
-       constexpr index_t M1 = get_warp_size() / (M2 * N0);
-       constexpr index_t M0 = kBlockSize / get_warp_size();
-       static_assert(kMPerBlock == M0 * M1 * M2 * M3);
-
-       return make_static_tile_distribution(
-           tile_distribution_encoding<sequence<>,
-                                      tuple<sequence<M0, M1, M2, M3>, sequence<N0, N1>>,
-                                      tuple<sequence<1>, sequence<1, 2, 1>>,
-                                      tuple<sequence<0>, sequence<1, 0, 2>>,
-                                      sequence<1, 2>,
-                                      sequence<3, 1>>{});
-   }
-
-   template <typename Problem>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
-   {
-       constexpr index_t kBlockSize = Problem::kBlockSize;
-       constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-
-       constexpr index_t N1 = GetTransposedAlignmentBias<Problem>();
-       constexpr index_t N0 = kNPerBlock / N1;
-
-       constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
-       static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
-       constexpr index_t M3 = total_pixels / N1;
-       constexpr index_t kKPack = GetSmemKPackBias<Problem>();
-       static_assert(kKPack % M3 == 0);
-       constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave
-       constexpr index_t M1 = get_warp_size() / (M2 * N0);
-       constexpr index_t M0 = kBlockSize / get_warp_size();
-
-       return make_static_tile_distribution(
-           tile_distribution_encoding<sequence<>,
-                                      tuple<sequence<M0, M1, M2, M3>, sequence<N0, N1>>,
-                                      tuple<sequence<1>, sequence<1, 2, 1>>,
-                                      tuple<sequence<0>, sequence<1, 0, 2>>,
-                                      sequence<2, 1>,
-                                      sequence<1, 3>>{});
-   }
-
-   template <typename BlockGemm>
-   CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution()
-   {
-       using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile());
-       return c_block_tensor_type::get_tile_distribution();
-   }
};

} // namespace ck_tile