Commit 80e89ebd authored by aska-0096

Minimal reproducible example for warp-specialized scheduling

parent af30d6b6
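
Context for the scheduling hints in this diff: __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) asks the compiler's instruction scheduler to place `size` instructions of the class selected by `mask` at that point (0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write), and __builtin_amdgcn_sched_barrier(0) pins a hard scheduling boundary. A minimal sketch of the recurring pattern below; the function and template parameter names are illustrative, not from this commit:

#include <hip/hip_runtime.h>

// Interleave each DS read with a fixed share of the MFMA budget. With the
// loop fully unrolled, one hint pair is emitted per iteration.
template <int MfmaPerDsRead, int DsReadGroup>
__device__ void schedule_ds_read_stage()
{
#pragma unroll
    for(int i = 0; i < DsReadGroup; ++i)
    {
        __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);             // DS read
        __builtin_amdgcn_sched_group_barrier(0x008, MfmaPerDsRead, 0); // MFMA
    }
    __builtin_amdgcn_sched_barrier(0); // hard boundary: hints stay in this stage
}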
......@@ -143,11 +143,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
256, 256, 128,
16, 16,
32, 32,
4, 4,
16, 16,
8, 8,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// clang-format on
......
......@@ -11,7 +11,7 @@ namespace ck {
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
// LocalSharedMemoryBuffer: 2
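// Presumably the bump from one to two LDS buffers is what lets the A-tile
// write for the next iteration overlap the LDS reads of the current one.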
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t BlockSize,
......@@ -148,7 +148,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
static constexpr index_t a_local_write_issue_stage = NPerXDL == 32 ? 1 : 2;
static constexpr index_t a_global_read_issue_stage = NPerXDL == 32 ? 2 : 4;
static constexpr index_t a_global_read_issue_stage_end = NPerXDL == 32 ? 3 : 6;
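// Stage windows consumed by the schedulers below (one stage per m0 iteration):
//   [0, a_local_write_issue_stage)                            : B global loads
//   [a_local_write_issue_stage, a_global_read_issue_stage)    : A LDS writes
//   [a_global_read_issue_stage, a_global_read_issue_stage_end): A global loads
//   [a_global_read_issue_stage_end, MRepeat)                  : A LDS reads only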
template <typename TileDesc_M0_M1_M2_K>
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
......@@ -187,7 +190,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
template <typename Stage>
__device__ static constexpr auto HotLoopScheduler(Stage stage)
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_read_grouped = KPack / A_K1;
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Inst_Num / num_ds_read_grouped;
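// ds_read hints are emitted back-to-back in groups of KPack / A_K1, so this
// count is expressed in groups rather than raw ds_read instructions.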
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
......@@ -199,12 +204,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
if constexpr(stage.value == 0)
if constexpr(stage.value < a_local_write_issue_stage)
{
constexpr auto issue_stages = a_local_write_issue_stage;
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
num_buffer_load_inst_b / staged_num_ds_read_inst_a;
num_buffer_load_inst_b / staged_num_ds_read_inst_a / issue_stages;
constexpr auto staged_num_mfma_per_buffer_load_b =
staged_num_mfma / num_buffer_load_inst_b;
issue_stages * staged_num_mfma / num_buffer_load_inst_b;
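// Dividing the load count by issue_stages and scaling the MFMA share by the
// same factor spreads the B global loads evenly across this stage window.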
// B global
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
......@@ -216,129 +223,105 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
ignore = ids_inst;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
0x008, staged_num_mfma_per_buffer_load_b - num_ds_read_grouped, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 1)
else if constexpr(stage.value < a_global_read_issue_stage)
{
constexpr auto staged_num_mfma_per_ds_write_a =
math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
constexpr auto issue_stages = a_global_read_issue_stage - a_local_write_issue_stage;
constexpr auto staged_num_ds_write_a_per_ds_read_a =
num_ds_write_inst_a / staged_num_ds_read_inst_a / issue_stages;
constexpr auto staged_num_mfma_per_ds_write_a =
issue_stages * staged_num_mfma / num_ds_write_inst_a;
// A local write
static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) {
ignore = idswrite_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
}
else
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
ignore = ids_inst;
__builtin_amdgcn_sched_group_barrier(0x008,
staged_num_ds_write_a_per_ds_read_a /
num_ds_read_grouped,
0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
}
});
});
__builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 2)
else if constexpr(stage.value < a_global_read_issue_stage_end)
{
constexpr auto staged_num_mfma_per_buffer_load_a =
math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a;
constexpr auto issue_stages = a_global_read_issue_stage_end - a_global_read_issue_stage;
constexpr auto staged_num_buffer_load_a_per_ds_read_a =
num_buffer_load_inst_a / staged_num_ds_read_inst_a / issue_stages;
constexpr auto staged_num_mfma_per_buffer_load_a =
issue_stages * staged_num_mfma / num_buffer_load_inst_a;
// A global
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_buffer_load_a_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
ignore = ibuf_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
}
else
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
ignore = ids_inst;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
});
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
0x008, staged_num_mfma_per_buffer_load_a - num_ds_read_grouped, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
}
});
__builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_sched_barrier(0);
}
else
{
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
ignore = ids_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
0x008, staged_num_mfma_per_ds_read_a / num_ds_read_grouped, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
});
__builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_sched_barrier(0);
}
}
template <typename Stage>
__device__ static constexpr auto EpilogueScheduler_1(Stage stage)
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_read_grouped = KPack / A_K1;
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Inst_Num / num_ds_read_grouped;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
......@@ -349,38 +332,46 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
if constexpr(stage.value == 0)
if constexpr(stage.value < a_local_write_issue_stage)
{
constexpr auto issue_stages = a_local_write_issue_stage;
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
num_buffer_load_inst_b / staged_num_ds_read_inst_a;
num_buffer_load_inst_b / staged_num_ds_read_inst_a / issue_stages;
constexpr auto staged_num_mfma_per_buffer_load_b =
staged_num_mfma / num_buffer_load_inst_b;
issue_stages * staged_num_mfma / num_buffer_load_inst_b;
// B global
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) {
static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
ignore = ibuf_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
ignore = ids_inst;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
0x008, staged_num_mfma_per_buffer_load_b - num_ds_read_grouped, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 1)
else if constexpr(stage.value < a_global_read_issue_stage)
{
#if 0
constexpr auto issue_stages = a_global_read_issue_stage - a_local_write_issue_stage;
constexpr auto staged_num_ds_write_a_per_ds_read_a =
num_ds_write_inst_a / staged_num_ds_read_inst_a;
constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a;
num_ds_write_inst_a / staged_num_ds_read_inst_a / issue_stages;
constexpr auto staged_num_mfma_per_ds_write_a =
issue_stages * staged_num_mfma / num_ds_write_inst_a;
// A local write
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
......@@ -392,74 +383,73 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
});
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA
static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
ignore = ids_inst;
__builtin_amdgcn_sched_group_barrier(0x008,
staged_num_ds_write_a_per_ds_read_a /
num_ds_read_grouped,
0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
#elif 1
constexpr auto staged_num_mfma_per_ds_write_a =
math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
});
// A local write
static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
// __builtin_amdgcn_sched_barrier(0);
}
else
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
ignore = ids_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
0x008, staged_num_mfma_per_ds_read_a / num_ds_read_grouped, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
}
});
#endif
__builtin_amdgcn_sched_barrier(0);
});
// __builtin_amdgcn_sched_barrier(0);
}
else
}
__device__ static constexpr auto EpilogueScheduler_2()
{
constexpr auto num_ds_read_grouped = KPack / A_K1;
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Inst_Num / num_ds_read_grouped;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma = num_mfma / MRepeat;
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
ignore = ids_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
0x008, staged_num_mfma_per_ds_read_a / num_ds_read_grouped, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
});
__builtin_amdgcn_sched_barrier(0);
}
}
__device__ static constexpr auto EpilogueScheduler_2()
template <typename Stage>
__device__ static constexpr auto HotLoopScheduler_B(Stage stage)
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_read_grouped = KPack / A_K1;
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Inst_Num / num_ds_read_grouped;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
......@@ -468,14 +458,116 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
if constexpr(stage.value < a_local_write_issue_stage)
{
constexpr auto issue_stages = a_local_write_issue_stage;
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
num_buffer_load_inst_b / staged_num_ds_read_inst_a / issue_stages;
constexpr auto staged_num_mfma_per_buffer_load_b =
issue_stages * staged_num_mfma / num_buffer_load_inst_b;
// B global
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
ignore = ibuf_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
ignore = ids_inst;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b - num_ds_read_grouped, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
// __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value < a_global_read_issue_stage)
{
constexpr auto issue_stages = a_global_read_issue_stage - a_local_write_issue_stage;
constexpr auto staged_num_ds_write_a_per_ds_read_a =
num_ds_write_inst_a / staged_num_ds_read_inst_a / issue_stages;
constexpr auto staged_num_mfma_per_ds_write_a =
issue_stages * staged_num_mfma / num_ds_write_inst_a;
// A local write
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) {
ignore = idswrite_inst;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
});
static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
ignore = ids_inst;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008,
staged_num_ds_write_a_per_ds_read_a /
num_ds_read_grouped,
0); // MFMA
});
});
// __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value < a_global_read_issue_stage_end)
{
constexpr auto issue_stages = a_global_read_issue_stage_end - a_global_read_issue_stage;
constexpr auto staged_num_buffer_load_a_per_ds_read_a =
num_buffer_load_inst_a / staged_num_ds_read_inst_a / issue_stages;
constexpr auto staged_num_mfma_per_buffer_load_a =
issue_stages * staged_num_mfma / num_buffer_load_inst_a;
// A global
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_buffer_load_a_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
ignore = ibuf_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
ignore = ids_inst;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - num_ds_read_grouped, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
// __builtin_amdgcn_sched_barrier(0);
}
else
{
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
__builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
static_for<0, num_ds_read_grouped, 1>{}([&](auto ids_inst) {
ignore = ids_inst;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_read_a / num_ds_read_grouped, 0); // MFMA
});
});
__builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_sched_barrier(0);
}
}
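// HotLoopScheduler_B mirrors HotLoopScheduler with the order inside each group
// swapped (DS read ahead of its MFMA share), so the two warp groups interleave
// complementary issue patterns.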
template <bool HasMainLoop,
......@@ -551,6 +643,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
__builtin_amdgcn_sched_barrier(0);
// 0: Warp specialized scheduling
// 1: unique scheduling
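// The #if 0 branch below keeps the single-schedule hot loop for reference; the
// active #elif 1 branch splits the block's warps: warp_id < 2 runs the loop
// with HotLoopScheduler (per the visible hunks), the rest with HotLoopScheduler_B.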
#if 0
// main body
if constexpr(HasMainLoop)
{
......@@ -558,21 +653,119 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
if constexpr(m0.value == 0)
if constexpr(m0.value == a_local_write_issue_stage)
{
a_blockwise_copy.RunWrite(a_block_desc,
a_block_buf.At(local_read_buf));
}
else if constexpr(m0.value == a_global_read_issue_stage)
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
}
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
if constexpr(m0.value == MRepeat - 1)
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
});
}
HotLoopScheduler(m0);
});
};
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));
}
#elif 1
const index_t warp_id = __builtin_amdgcn_readfirstlane(get_warp_local_1d_id());
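// readfirstlane broadcasts lane 0's warp id, so the warp-group branch below is
// wavefront-uniform (resolved in scalar registers, no divergence).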
if(warp_id < 2)
{
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
}
else if constexpr(m0.value == 1)
static_for<0, MRepeat, 1>{}([&](auto m0) {
if constexpr(m0.value == a_local_write_issue_stage)
{
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
a_blockwise_copy.RunWrite(a_block_desc,
a_block_buf.At(local_read_buf));
}
else if constexpr(m0.value == 2)
else if constexpr(m0.value == a_global_read_issue_stage)
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
......@@ -586,13 +779,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
2,
I0,
I0,
k0,
I0,
ik))>{}];
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
......@@ -620,17 +807,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
make_tuple(
Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
I0),
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
});
}
......@@ -639,17 +820,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
make_tuple(
Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
I0),
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
});
}
......@@ -664,19 +839,116 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
i += 2;
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Even)
}
else
{
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
if constexpr(m0.value == 0)
if constexpr(m0.value == a_local_write_issue_stage)
{
a_blockwise_copy.RunWrite(a_block_desc,
a_block_buf.At(local_read_buf));
}
else if constexpr(m0.value == a_global_read_issue_stage)
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
}
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
if constexpr(m0.value == MRepeat - 1)
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
});
}
HotLoopScheduler_B(m0);
});
};
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));
}
}
#endif
// tail
if constexpr(TailNum == TailNumber::Even)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
}
else if constexpr(m0.value == MRepeat - 1)
static_for<0, MRepeat, 1>{}([&](auto m0) {
if constexpr(m0.value == a_local_write_issue_stage)
{
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
}
......@@ -745,8 +1017,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
......@@ -767,13 +1039,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
if constexpr(m0.value != (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0),
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
});
......